feat: add river-backed async task queue
This commit is contained in:
parent
d69aaed444
commit
7e220b7477
@ -33,18 +33,19 @@ func main() {
|
||||
if recovery, err := db.RecoverInterruptedRuntimeState(ctx); err != nil {
|
||||
logger.Error("recover interrupted runtime state failed", "error", err)
|
||||
os.Exit(1)
|
||||
} else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 {
|
||||
} else if recovery.ReleasedConcurrencyLeases > 0 || recovery.ReleasedRateReservations > 0 || recovery.FailedAttempts > 0 || recovery.FailedTasks > 0 || recovery.RequeuedAsyncTasks > 0 {
|
||||
logger.Warn("interrupted runtime state recovered",
|
||||
"releasedConcurrencyLeases", recovery.ReleasedConcurrencyLeases,
|
||||
"releasedRateReservations", recovery.ReleasedRateReservations,
|
||||
"failedAttempts", recovery.FailedAttempts,
|
||||
"failedTasks", recovery.FailedTasks,
|
||||
"requeuedAsyncTasks", recovery.RequeuedAsyncTasks,
|
||||
)
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Addr: cfg.HTTPAddr,
|
||||
Handler: httpapi.NewServer(cfg, db, logger),
|
||||
Handler: httpapi.NewServerWithContext(ctx, cfg, db, logger),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
|
||||
@ -4,14 +4,28 @@ go 1.23
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/jackc/pgx/v5 v5.7.2
|
||||
golang.org/x/crypto v0.31.0
|
||||
github.com/jackc/pgx/v5 v5.9.2
|
||||
github.com/riverqueue/river v0.24.0
|
||||
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0
|
||||
github.com/riverqueue/river/rivertype v0.24.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/riverqueue/river/riverdriver v0.24.0 // indirect
|
||||
github.com/riverqueue/river/rivershared v0.24.0 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
go.uber.org/goleak v1.3.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@ -3,28 +3,63 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
|
||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
|
||||
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/riverqueue/river v0.24.0 h1:CesL6vymWgz0d+zNwtnSGRWaB+E8Dax+o9cxD7sUmKc=
|
||||
github.com/riverqueue/river v0.24.0/go.mod h1:UZ3AxU5t6WtyqNssaea/AkRS8h/kJ+E9ImSB3xyb3ns=
|
||||
github.com/riverqueue/river/riverdriver v0.24.0 h1:HqGgGkls11u+YKDA7cKOdYKlQwRNJyHuGa3UtOvpdT0=
|
||||
github.com/riverqueue/river/riverdriver v0.24.0/go.mod h1:dEew9DDIKenNvzpm8Edw8+PkqP3c0zl1fKjiQTq2n/w=
|
||||
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0 h1:yV37OIbRrhRwIiGeRT7P4D3szhAemu87BgCf8gTCoU4=
|
||||
github.com/riverqueue/river/riverdriver/riverpgxv5 v0.24.0/go.mod h1:QfznySVKC4ljx53syd/bA/LRSsydAyuD3Q9/EbSniKA=
|
||||
github.com/riverqueue/river/rivershared v0.24.0 h1:KysokksW75pug2a5RTOc6WESOupWmsylVc6VWvAx+4Y=
|
||||
github.com/riverqueue/river/rivershared v0.24.0/go.mod h1:UIBfSdai0oWFlwFcoqG4DZX83iA/fLWTEBGrj7Oe1ho=
|
||||
github.com/riverqueue/river/rivertype v0.24.0 h1:xrQZm/h6U8TBPyTsQPYD5leOapuoBAcdz30bdBwTqOg=
|
||||
github.com/riverqueue/river/rivertype v0.24.0/go.mod h1:lmdl3vLNDfchDWbYdW2uAocIuwIN+ZaXqAukdSCFqWs=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@ -329,6 +329,7 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
|
||||
var gotModel string
|
||||
var gotText string
|
||||
var gotFirstFrameRole string
|
||||
var submittedRemoteTaskID string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
switch r.Method + " " + r.URL.Path {
|
||||
@ -385,6 +386,13 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
|
||||
"volcesPollTimeoutSeconds": 1,
|
||||
},
|
||||
},
|
||||
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
|
||||
submittedRemoteTaskID = remoteTaskID
|
||||
if payload["id"] != "cgt-test" {
|
||||
t.Fatalf("unexpected submitted payload: %+v", payload)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("run volces video: %v", err)
|
||||
@ -392,6 +400,9 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
|
||||
if submitPath != "/contents/generations/tasks" || pollPath != "/contents/generations/tasks/cgt-test" || gotAuth != "Bearer volces-key" {
|
||||
t.Fatalf("unexpected paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth)
|
||||
}
|
||||
if submittedRemoteTaskID != "cgt-test" {
|
||||
t.Fatalf("remote task submit callback did not receive task id, got %q", submittedRemoteTaskID)
|
||||
}
|
||||
if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" {
|
||||
t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole)
|
||||
}
|
||||
@ -407,6 +418,53 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) {
|
||||
var submitCalled bool
|
||||
var pollPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method + " " + r.URL.Path {
|
||||
case "POST /contents/generations/tasks":
|
||||
submitCalled = true
|
||||
t.Fatalf("resume should skip upstream submit when remote task id exists")
|
||||
case "GET /contents/generations/tasks/cgt-existing":
|
||||
pollPath = r.URL.Path
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "cgt-existing",
|
||||
"status": "succeeded",
|
||||
"created_at": 789,
|
||||
"content": map[string]any{"video_url": "https://example.com/resumed.mp4"},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
response, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "豆包Seedance-2.0",
|
||||
Body: map[string]any{"prompt": "resume polling", "pollIntervalMs": 100, "pollTimeoutSeconds": 1},
|
||||
RemoteTaskID: "cgt-existing",
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
BaseURL: server.URL,
|
||||
ModelName: "豆包Seedance-2.0",
|
||||
Credentials: map[string]any{"apiKey": "volces-key"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("resume volces video: %v", err)
|
||||
}
|
||||
if submitCalled || pollPath != "/contents/generations/tasks/cgt-existing" {
|
||||
t.Fatalf("resume should poll existing task only, submit=%v poll=%s", submitCalled, pollPath)
|
||||
}
|
||||
data, _ := response.Result["data"].([]any)
|
||||
item, _ := data[0].(map[string]any)
|
||||
if response.Result["upstream_task_id"] != "cgt-existing" || item["url"] != "https://example.com/resumed.mp4" {
|
||||
t.Fatalf("unexpected resumed response: %+v", response.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func extractText(result map[string]any) string {
|
||||
choices, _ := result["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
|
||||
@ -11,14 +11,17 @@ import (
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Kind string
|
||||
ModelType string
|
||||
Model string
|
||||
Body map[string]any
|
||||
Candidate store.RuntimeModelCandidate
|
||||
HTTPClient *http.Client
|
||||
Stream bool
|
||||
StreamDelta StreamDelta
|
||||
Kind string
|
||||
ModelType string
|
||||
Model string
|
||||
Body map[string]any
|
||||
Candidate store.RuntimeModelCandidate
|
||||
HTTPClient *http.Client
|
||||
RemoteTaskID string
|
||||
RemoteTaskPayload map[string]any
|
||||
OnRemoteTaskSubmitted func(remoteTaskID string, payload map[string]any) error
|
||||
Stream bool
|
||||
StreamDelta StreamDelta
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
|
||||
@ -67,16 +67,25 @@ func (c VolcesClient) runImage(ctx context.Context, request Request, apiKey stri
|
||||
}
|
||||
|
||||
func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey string) (Response, error) {
|
||||
body := volcesVideoBody(request)
|
||||
submitStartedAt := time.Now()
|
||||
submitResult, submitRequestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
|
||||
submitFinishedAt := time.Now()
|
||||
if err != nil {
|
||||
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, submitFinishedAt)
|
||||
}
|
||||
upstreamTaskID := strings.TrimSpace(stringFromAny(submitResult["id"]))
|
||||
submitRequestID := strings.TrimSpace(request.RemoteTaskID)
|
||||
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
||||
if upstreamTaskID == "" {
|
||||
return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false}
|
||||
body := volcesVideoBody(request)
|
||||
submitResult, requestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
|
||||
submitRequestID = requestID
|
||||
if err != nil {
|
||||
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, time.Now())
|
||||
}
|
||||
upstreamTaskID = strings.TrimSpace(stringFromAny(submitResult["id"]))
|
||||
if upstreamTaskID == "" {
|
||||
return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false}
|
||||
}
|
||||
if request.OnRemoteTaskSubmitted != nil {
|
||||
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, submitResult); err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interval := volcesPollInterval(request)
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@ -37,7 +38,11 @@ func TestCoreLocalFlow(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
handler := NewServer(config.Config{
|
||||
assertRuntimeRecoveryReleasesPendingRateReservations(t, ctx, db)
|
||||
|
||||
serverCtx, cancelServer := context.WithCancel(ctx)
|
||||
defer cancelServer()
|
||||
handler := NewServerWithContext(serverCtx, config.Config{
|
||||
AppEnv: "test",
|
||||
HTTPAddr: ":0",
|
||||
DatabaseURL: databaseURL,
|
||||
@ -169,13 +174,14 @@ func TestCoreLocalFlow(t *testing.T) {
|
||||
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks`).Scan(&taskCountBefore); err != nil {
|
||||
t.Fatalf("count tasks before scoped request: %v", err)
|
||||
}
|
||||
defaultImageModel := "openai:gpt-image-1"
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"model": "gpt-image-1",
|
||||
"model": defaultImageModel,
|
||||
"prompt": "scope should block this",
|
||||
}, http.StatusForbidden, nil)
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/pricing/estimate", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||
"kind": "images.generations",
|
||||
"model": "gpt-image-1",
|
||||
"model": defaultImageModel,
|
||||
"prompt": "scope should block this estimate",
|
||||
}, http.StatusForbidden, nil)
|
||||
var taskCountAfter int
|
||||
@ -309,8 +315,9 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
Result map[string]any `json:"result"`
|
||||
} `json:"task"`
|
||||
}
|
||||
defaultTextModel := "openai:gpt-4o-mini"
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": "gpt-4o-mini",
|
||||
"model": defaultTextModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
@ -334,7 +341,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
|
||||
var compatChat map[string]any
|
||||
doJSON(t, server.URL, http.MethodPost, "/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": "gpt-4o-mini",
|
||||
"model": defaultTextModel,
|
||||
"runMode": "simulation",
|
||||
"messages": []map[string]any{{"role": "user", "content": "ping"}},
|
||||
"simulation": true,
|
||||
@ -352,7 +359,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/generations", apiKeyResponse.Secret, map[string]any{
|
||||
"model": "gpt-image-1",
|
||||
"model": defaultImageModel,
|
||||
"runMode": "simulation",
|
||||
"prompt": "a tiny gateway console",
|
||||
"size": "1024x1024",
|
||||
@ -372,7 +379,7 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{
|
||||
"model": "gpt-image-1",
|
||||
"model": defaultImageModel,
|
||||
"runMode": "simulation",
|
||||
"prompt": "replace background with clean studio light",
|
||||
"image": "https://example.com/source.png",
|
||||
@ -384,6 +391,42 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
||||
t.Fatalf("unexpected image edit task: %+v", imageEditResponse.Task)
|
||||
}
|
||||
|
||||
doubaoLiteImageEditModel := "doubao-5.0-lite图像编辑"
|
||||
var doubaoLitePlatformModel struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", loginResponse.AccessToken, map[string]any{
|
||||
"canonicalModelKey": "easyai:doubao-5.0-lite图像编辑",
|
||||
"modelName": doubaoLiteImageEditModel,
|
||||
"modelAlias": doubaoLiteImageEditModel,
|
||||
"modelType": []string{"image_edit", "image_generate"},
|
||||
"displayName": doubaoLiteImageEditModel,
|
||||
}, http.StatusCreated, &doubaoLitePlatformModel)
|
||||
var doubaoLiteAsyncEdit struct {
|
||||
TaskID string `json:"taskId"`
|
||||
Task struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
AsyncMode bool `json:"asyncMode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/images/edits", apiKeyResponse.Secret, map[string]any{
|
||||
"model": doubaoLiteImageEditModel,
|
||||
"runMode": "simulation",
|
||||
"prompt": "turn the attached bright desktop object into a clean product-style render",
|
||||
"image": "https://example.com/doubao-lite-source.png",
|
||||
"size": "2048x2048",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &doubaoLiteAsyncEdit)
|
||||
if doubaoLiteAsyncEdit.TaskID == "" || !doubaoLiteAsyncEdit.Task.AsyncMode {
|
||||
t.Fatalf("doubao-5.0-lite image edit async task should be accepted: %+v", doubaoLiteAsyncEdit)
|
||||
}
|
||||
doubaoLiteAsyncEditDone := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, doubaoLiteAsyncEdit.TaskID, []string{"succeeded"}, 5*time.Second)
|
||||
if doubaoLiteAsyncEditDone.Status != "succeeded" {
|
||||
t.Fatalf("doubao-5.0-lite image edit async task should succeed through river queue, got %+v", doubaoLiteAsyncEditDone)
|
||||
}
|
||||
|
||||
var gptImageModelTypesRaw []byte
|
||||
if err := testPool.QueryRow(ctx, `
|
||||
SELECT model_type
|
||||
@ -641,6 +684,7 @@ WHERE reference_type = 'gateway_task'
|
||||
}
|
||||
|
||||
rateLimitedModel := "rate-limit-smoke-" + suffixText
|
||||
rateLimitWindowSeconds := 3
|
||||
var rateLimitPolicySet struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
@ -652,7 +696,7 @@ WHERE reference_type = 'gateway_task'
|
||||
"maxAttempts": 1,
|
||||
},
|
||||
"rateLimitPolicy": map[string]any{
|
||||
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": 60}},
|
||||
"rules": []map[string]any{{"metric": "rpm", "limit": 1, "windowSeconds": rateLimitWindowSeconds}},
|
||||
},
|
||||
}, http.StatusCreated, &rateLimitPolicySet)
|
||||
var rateLimitPlatformModel map[string]any
|
||||
@ -682,6 +726,7 @@ WHERE reference_type = 'gateway_task'
|
||||
if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" {
|
||||
t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task)
|
||||
}
|
||||
waitForRateLimitWindowHead(t, rateLimitWindowSeconds)
|
||||
var rateLimitTaskOne struct {
|
||||
Task struct {
|
||||
Status string `json:"status"`
|
||||
@ -713,6 +758,38 @@ WHERE reference_type = 'gateway_task'
|
||||
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
|
||||
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
|
||||
}
|
||||
var asyncRateLimitTask struct {
|
||||
TaskID string `json:"taskId"`
|
||||
Task struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
AsyncMode bool `json:"asyncMode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": rateLimitedModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 5,
|
||||
"messages": []map[string]any{{"role": "user", "content": "async queued"}},
|
||||
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask)
|
||||
if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode {
|
||||
t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask)
|
||||
}
|
||||
asyncRateLimitDetail := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"queued"}, 2*time.Second)
|
||||
if asyncRateLimitDetail.Status != "queued" {
|
||||
t.Fatalf("async rate-limited task should return to queued state, got %+v", asyncRateLimitDetail)
|
||||
}
|
||||
if len(asyncRateLimitDetail.Attempts) == 0 || asyncRateLimitDetail.Attempts[0].ErrorCode != "rate_limit" {
|
||||
t.Fatalf("async rate-limited task should record a rate_limit attempt before requeue: %+v", asyncRateLimitDetail)
|
||||
}
|
||||
asyncRateLimitCompleted := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"succeeded"}, time.Duration(rateLimitWindowSeconds+3)*time.Second)
|
||||
if asyncRateLimitCompleted.Status != "succeeded" {
|
||||
t.Fatalf("async rate-limited task should be pulled from queue after the limit window resets, got %+v", asyncRateLimitCompleted)
|
||||
}
|
||||
if len(asyncRateLimitCompleted.Attempts) < 2 || asyncRateLimitCompleted.Attempts[len(asyncRateLimitCompleted.Attempts)-1].Status != "succeeded" {
|
||||
t.Fatalf("async rate-limited task should create a new successful attempt after requeue: %+v", asyncRateLimitCompleted)
|
||||
}
|
||||
|
||||
videoRouteModel := "video-route-smoke-" + suffixText
|
||||
var videoRoutePlatformModel map[string]any
|
||||
@ -823,16 +900,19 @@ WHERE reference_type = 'gateway_task'
|
||||
Metrics map[string]any `json:"metrics"`
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks/"+failoverTask.Task.ID, apiKeyResponse.Secret, nil, http.StatusOK, &failoverDetail)
|
||||
if len(failoverDetail.Attempts) != 2 {
|
||||
t.Fatalf("failover task history should include two attempts, got %+v", failoverDetail.Attempts)
|
||||
if len(failoverDetail.Attempts) != 3 {
|
||||
t.Fatalf("failover task history should include two failed retries plus one successful failover attempt, got %+v", failoverDetail.Attempts)
|
||||
}
|
||||
if failoverDetail.Attempts[0].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[0].Status != "failed" || !failoverDetail.Attempts[0].Retryable || failoverDetail.Attempts[0].ErrorCode == "" {
|
||||
t.Fatalf("first failover attempt should preserve failed platform and reason: %+v", failoverDetail.Attempts[0])
|
||||
}
|
||||
if failoverDetail.Attempts[1].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[1].Status != "succeeded" || failoverDetail.Attempts[1].ResponseMS <= 0 {
|
||||
t.Fatalf("second failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[1])
|
||||
if failoverDetail.Attempts[1].PlatformName != "OpenAI Retryable Failure" || failoverDetail.Attempts[1].Status != "failed" || !failoverDetail.Attempts[1].Retryable {
|
||||
t.Fatalf("second failover attempt should preserve the same-client retry failure: %+v", failoverDetail.Attempts[1])
|
||||
}
|
||||
if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 2 {
|
||||
if failoverDetail.Attempts[2].PlatformName != "OpenAI Retry Success" || failoverDetail.Attempts[2].Status != "succeeded" || failoverDetail.Attempts[2].ResponseMS <= 0 {
|
||||
t.Fatalf("third failover attempt should preserve successful platform: %+v", failoverDetail.Attempts[2])
|
||||
}
|
||||
if summary, ok := failoverDetail.Metrics["attempts"].([]any); !ok || len(summary) != 3 {
|
||||
t.Fatalf("task metrics should keep attempt-chain summary, got %+v", failoverDetail.Metrics)
|
||||
}
|
||||
|
||||
@ -1045,6 +1125,9 @@ WHERE m.platform_id = $1::uuid
|
||||
if resp.StatusCode != http.StatusOK || !bytes.Contains(body, []byte("task.completed")) {
|
||||
t.Fatalf("unexpected events response status=%d body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
if !bytes.Contains(body, []byte("task.running")) {
|
||||
t.Fatalf("events response should include running transition event body=%s", string(body))
|
||||
}
|
||||
if !bytes.Contains(body, []byte("task.progress")) {
|
||||
t.Fatalf("events response should include progress events body=%s", string(body))
|
||||
}
|
||||
@ -1074,6 +1157,47 @@ WHERE m.platform_id = $1::uuid
|
||||
if callbackRows == 0 {
|
||||
t.Fatal("task progress callback outbox should receive events")
|
||||
}
|
||||
|
||||
var restartAsyncTask struct {
|
||||
TaskID string `json:"taskId"`
|
||||
Task struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
AsyncMode bool `json:"asyncMode"`
|
||||
} `json:"task"`
|
||||
}
|
||||
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
||||
"model": defaultTextModel,
|
||||
"runMode": "simulation",
|
||||
"simulation": true,
|
||||
"simulationDurationMs": 2000,
|
||||
"messages": []map[string]any{{"role": "user", "content": "river worker restart"}},
|
||||
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask)
|
||||
if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode {
|
||||
t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask)
|
||||
}
|
||||
restartRunning := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"running"}, 3*time.Second)
|
||||
if restartRunning.Status != "running" {
|
||||
t.Fatalf("restart async task should be running before worker restart, got %+v", restartRunning)
|
||||
}
|
||||
cancelServer()
|
||||
serverCtx2, cancelServer2 := context.WithCancel(ctx)
|
||||
defer cancelServer2()
|
||||
server2 := httptest.NewServer(NewServerWithContext(serverCtx2, config.Config{
|
||||
AppEnv: "test",
|
||||
HTTPAddr: ":0",
|
||||
DatabaseURL: databaseURL,
|
||||
IdentityMode: "hybrid",
|
||||
JWTSecret: "test-secret",
|
||||
TaskProgressCallbackEnabled: true,
|
||||
TaskProgressCallbackURL: "http://callback.local/task-progress",
|
||||
CORSAllowedOrigin: "*",
|
||||
}, db, slog.New(slog.NewTextHandler(io.Discard, nil))))
|
||||
defer server2.Close()
|
||||
restartRecovered := waitForTaskStatus(t, server2.URL, apiKeyResponse.Secret, restartAsyncTask.TaskID, []string{"succeeded"}, 8*time.Second)
|
||||
if restartRecovered.Status != "succeeded" {
|
||||
t.Fatalf("river worker restart should recover and finish async task, got %+v", restartRecovered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) {
|
||||
@ -1089,6 +1213,51 @@ func TestOriginAllowedSupportsCommaSeparatedOrigins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// assertRuntimeRecoveryReleasesPendingRateReservations verifies that
// RecoverInterruptedRuntimeState releases rate limit reservations that were
// left pending (reserved but never committed) by an interrupted run. It
// creates an async task, reserves client-scoped "rpm" capacity for it, runs
// recovery, and asserts both the recovery counters and the backing
// gateway_rate_limit_counters row report the reservation as released.
func assertRuntimeRecoveryReleasesPendingRateReservations(t *testing.T, ctx context.Context, db *store.Store) {
	t.Helper()
	// Unique suffix keeps this run's rows disjoint from prior test runs
	// against the same database.
	suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
	task, err := db.CreateTask(ctx, store.CreateTaskInput{
		Kind:    "chat.completions",
		Model:   "recovery-reservation-smoke-" + suffix,
		RunMode: "test",
		Async:   true,
		Request: map[string]any{"messages": []any{}},
	}, &auth.User{ID: "recovery-user-" + suffix})
	if err != nil {
		t.Fatalf("create recovery reservation task: %v", err)
	}
	scopeKey := "recovery-client-" + suffix
	// Reserve capacity with an empty attempt ID and deliberately never commit
	// it, simulating a run interrupted mid-attempt.
	if _, err := db.ReserveRateLimits(ctx, task.ID, "", []store.RateLimitReservation{{
		ScopeType:     "client",
		ScopeKey:      scopeKey,
		Metric:        "rpm",
		Limit:         10,
		Amount:        3,
		WindowSeconds: 60,
	}}); err != nil {
		t.Fatalf("reserve recovery rate limit: %v", err)
	}
	recovery, err := db.RecoverInterruptedRuntimeState(ctx)
	if err != nil {
		t.Fatalf("recover interrupted runtime state with pending reservation: %v", err)
	}
	if recovery.ReleasedRateReservations == 0 {
		t.Fatalf("recovery should release pending rate reservation, got %+v", recovery)
	}
	// Double-check against the raw counter row: MAX over matching rows must
	// be back at zero once the reservation is released.
	var reservedValue float64
	if err := db.Pool().QueryRow(ctx, `
SELECT COALESCE(MAX(reserved_value), 0)::float8
FROM gateway_rate_limit_counters
WHERE scope_type = 'client'
  AND scope_key = $1
  AND metric = 'rpm'`, scopeKey).Scan(&reservedValue); err != nil {
		t.Fatalf("query recovered rate limit counter: %v", err)
	}
	if reservedValue != 0 {
		t.Fatalf("recovery should release reserved counter value, got %f", reservedValue)
	}
}
|
||||
|
||||
func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
|
||||
t.Helper()
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
@ -1114,6 +1283,11 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
|
||||
}
|
||||
|
||||
// doJSON issues a JSON request with no extra headers and asserts the status;
// it is a convenience wrapper over doJSONWithHeaders.
func doJSON(t *testing.T, baseURL string, method string, path string, token string, payload any, expectedStatus int, out any) {
	t.Helper()
	doJSONWithHeaders(t, baseURL, method, path, token, payload, nil, expectedStatus, out)
}
|
||||
|
||||
func doJSONWithHeaders(t *testing.T, baseURL string, method string, path string, token string, payload any, headers map[string]string, expectedStatus int, out any) {
|
||||
t.Helper()
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
@ -1133,6 +1307,9 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("%s %s: %v", method, path, err)
|
||||
@ -1149,6 +1326,50 @@ func doJSON(t *testing.T, baseURL string, method string, path string, token stri
|
||||
}
|
||||
}
|
||||
|
||||
// taskWaitDetail is the subset of the task detail response that the polling
// helpers inspect: overall status plus per-attempt status and error code.
type taskWaitDetail struct {
	ID     string `json:"id"`
	Status string `json:"status"`
	// Attempts mirrors the "attempts" array in the detail payload; only the
	// fields asserted by tests are decoded.
	Attempts []struct {
		Status    string `json:"status"`
		ErrorCode string `json:"errorCode"`
	} `json:"attempts"`
}
|
||||
|
||||
func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string, statuses []string, timeout time.Duration) taskWaitDetail {
|
||||
t.Helper()
|
||||
wanted := map[string]bool{}
|
||||
for _, status := range statuses {
|
||||
wanted[status] = true
|
||||
}
|
||||
var detail taskWaitDetail
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskID, token, nil, http.StatusOK, &detail)
|
||||
if wanted[detail.Status] {
|
||||
return detail
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
// waitForRateLimitWindowHead blocks until wall-clock time sits within the
// first 300ms of a windowSeconds-aligned rate limit window, so a following
// request lands at the head of a fresh window. It fails the test if no
// window head is observed within one full window plus a second of slack.
// Non-positive windows need no alignment and return immediately.
func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) {
	t.Helper()
	if windowSeconds <= 0 {
		return
	}
	window := time.Duration(windowSeconds) * time.Second
	giveUpAt := time.Now().Add(window + time.Second)
	for time.Now().Before(giveUpAt) {
		// Offset of "now" inside the current window, derived from the
		// epoch-aligned nanosecond clock.
		intoWindow := time.Duration(time.Now().UnixNano() % int64(window))
		if intoWindow < 300*time.Millisecond {
			return
		}
		time.Sleep(20 * time.Millisecond)
	}
	t.Fatalf("timed out waiting for %ds rate limit window head", windowSeconds)
}
|
||||
|
||||
func stringSliceContains(values []string, target string) bool {
|
||||
for _, value := range values {
|
||||
if value == target {
|
||||
|
||||
@ -542,11 +542,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
||||
return
|
||||
}
|
||||
asyncMode := asyncRequest(r)
|
||||
|
||||
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
||||
Kind: kind,
|
||||
Model: model,
|
||||
RunMode: runModeFromRequest(body),
|
||||
Async: asyncMode,
|
||||
Request: body,
|
||||
}, user)
|
||||
if err != nil {
|
||||
@ -554,6 +556,14 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
writeError(w, http.StatusInternalServerError, "create task failed")
|
||||
return
|
||||
}
|
||||
if asyncMode {
|
||||
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
|
||||
return
|
||||
}
|
||||
writeTaskAccepted(w, task)
|
||||
return
|
||||
}
|
||||
if compatible {
|
||||
if boolValue(body, "stream") {
|
||||
flusher := prepareCompatibleStream(w)
|
||||
@ -602,13 +612,23 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusAccepted, map[string]any{
|
||||
"task": result.Task,
|
||||
"next": map[string]string{
|
||||
"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
|
||||
"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
|
||||
},
|
||||
})
|
||||
writeTaskAccepted(w, result.Task)
|
||||
})
|
||||
}
|
||||
|
||||
func asyncRequest(r *http.Request) bool {
|
||||
value := strings.TrimSpace(strings.ToLower(r.Header.Get("x-async")))
|
||||
return value == "1" || value == "true" || value == "yes" || value == "on"
|
||||
}
|
||||
|
||||
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
|
||||
writeJSON(w, http.StatusAccepted, map[string]any{
|
||||
"taskId": task.ID,
|
||||
"task": task,
|
||||
"next": map[string]string{
|
||||
"events": fmt.Sprintf("/api/v1/tasks/%s/events", task.ID),
|
||||
"detail": fmt.Sprintf("/api/v1/tasks/%s", task.ID),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
cfg config.Config
|
||||
store *store.Store
|
||||
auth *auth.Authenticator
|
||||
@ -20,7 +22,12 @@ type Server struct {
|
||||
}
|
||||
|
||||
// NewServer builds the HTTP handler with a background context. Prefer
// NewServerWithContext when the caller owns a cancellable lifecycle context,
// since that context is what shuts down the async queue worker.
func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
	return NewServerWithContext(context.Background(), cfg, db, logger)
}
|
||||
|
||||
func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
|
||||
server := &Server{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
store: db,
|
||||
auth: auth.New(cfg.JWTSecret, cfg.ServerMainBaseURL, cfg.ServerMainInternalToken),
|
||||
@ -28,6 +35,7 @@ func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Han
|
||||
logger: logger,
|
||||
}
|
||||
server.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey
|
||||
server.runner.StartAsyncQueueWorker(ctx)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /healthz", server.health)
|
||||
@ -146,7 +154,7 @@ func (s *Server) cors(next http.Handler) http.Handler {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key, X-Async")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
}
|
||||
if r.Method == http.MethodOptions {
|
||||
|
||||
@ -2,13 +2,49 @@ package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
// localRateLimitError wraps a client-shaped rate limit error together with
// the underlying store error and an optional retry-after hint, so callers
// can match it with errors.As (*clients.ClientError) as well as errors.Is
// (store.ErrRateLimited) via Unwrap.
type localRateLimitError struct {
	clientErr  *clients.ClientError // client-facing error; may be nil
	cause      error                // underlying store error; may be nil
	retryAfter time.Duration        // suggested wait before retrying; 0 when unknown
}
|
||||
|
||||
func (e *localRateLimitError) Error() string {
|
||||
if e == nil || e.clientErr == nil {
|
||||
return store.ErrRateLimited.Error()
|
||||
}
|
||||
return e.clientErr.Error()
|
||||
}
|
||||
|
||||
func (e *localRateLimitError) Unwrap() []error {
|
||||
if e == nil || e.clientErr == nil {
|
||||
if e != nil && e.cause != nil {
|
||||
return []error{e.cause}
|
||||
}
|
||||
return []error{store.ErrRateLimited}
|
||||
}
|
||||
if e.cause != nil {
|
||||
return []error{e.clientErr, e.cause}
|
||||
}
|
||||
return []error{e.clientErr, store.ErrRateLimited}
|
||||
}
|
||||
|
||||
func localRateLimitRetryAfter(err error) time.Duration {
|
||||
var limitErr *localRateLimitError
|
||||
if errors.As(err, &limitErr) && limitErr.retryAfter > 0 {
|
||||
return limitErr.retryAfter
|
||||
}
|
||||
return store.RateLimitRetryAfter(err)
|
||||
}
|
||||
|
||||
func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation {
|
||||
out := make([]store.RateLimitReservation, 0)
|
||||
out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...)
|
||||
|
||||
216
apps/api/internal/runner/queue_worker.go
Normal file
216
apps/api/internal/runner/queue_worker.go
Normal file
@ -0,0 +1,216 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
"github.com/riverqueue/river"
|
||||
"github.com/riverqueue/river/riverdriver/riverpgxv5"
|
||||
"github.com/riverqueue/river/rivermigrate"
|
||||
"github.com/riverqueue/river/rivertype"
|
||||
)
|
||||
|
||||
// asyncTaskQueueName is the dedicated river queue for gateway task execution.
const asyncTaskQueueName = "gateway_tasks"

// asyncTaskArgs is the river job payload: just the gateway task ID. The
// `river:"unique"` tag includes the field in river's unique-job key so the
// same task is not enqueued twice (see asyncTaskInsertOpts).
type asyncTaskArgs struct {
	TaskID string `json:"task_id" river:"unique"`
}

// Kind identifies this job type in river's jobs table.
func (asyncTaskArgs) Kind() string { return "gateway_task_run" }
|
||||
|
||||
// asyncTaskWorker is the river worker that executes gateway tasks. It embeds
// WorkerDefaults for river's default worker hooks and carries the runner
// service it dispatches into.
type asyncTaskWorker struct {
	river.WorkerDefaults[asyncTaskArgs]

	service *Service
}
|
||||
|
||||
// Work executes one queued gateway task. It is idempotent for tasks already
// in a terminal state, and distinguishes three non-success outcomes:
//   - the run re-queued the task (TaskQueuedError): snooze the river job for
//     the requested delay instead of failing it;
//   - the worker context was cancelled mid-run (shutdown): persist the task
//     back to queued using a non-cancellable context, then snooze;
//   - any other run error: the failure is already recorded on the task by
//     the runner, so return nil and let river complete the job rather than
//     retry it.
func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error {
	task, err := w.service.store.GetTask(ctx, job.Args.TaskID)
	if err != nil {
		return err
	}
	// Terminal states: nothing to do (the job may be a duplicate insert or a
	// late retry for a task that already finished).
	if task.Status == "succeeded" || task.Status == "failed" || task.Status == "cancelled" {
		return nil
	}
	result, runErr := w.service.Execute(ctx, task, authUserFromTask(task))
	if runErr == nil {
		w.service.logger.Debug("river async task completed", "taskID", task.ID, "status", result.Task.Status, "riverJobID", job.ID)
		return nil
	}
	var queuedErr *TaskQueuedError
	if errors.As(runErr, &queuedErr) {
		// The task re-queued itself (e.g. local rate limit); come back after
		// the delay it asked for.
		return river.JobSnooze(queuedErr.Delay)
	}
	if ctx.Err() != nil {
		// Shutdown raced the run: requeue with a detached context so the
		// write itself is not cancelled along with the worker.
		queued, queueErr := w.service.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
		if queueErr != nil {
			return queueErr
		}
		w.service.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", queued.Status, "riverJobID", job.ID)
		return river.JobSnooze(0)
	}
	w.service.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", runErr, "riverJobID", job.ID)
	return nil
}
|
||||
|
||||
// StartAsyncQueueWorker starts the river-backed async queue. Queue startup is
// a hard prerequisite for serving async requests, so a startup failure is
// logged and then escalated via panic rather than returned to the caller.
func (s *Service) StartAsyncQueueWorker(ctx context.Context) {
	if err := s.startRiverQueue(ctx); err != nil {
		s.logger.Error("start river async queue failed", "error", err)
		panic(err)
	}
}
|
||||
|
||||
// startRiverQueue wires up the river-backed async queue end to end: apply
// river's schema migrations, register the task worker, build and start the
// client, re-enqueue any persisted recoverable tasks, and arrange a graceful
// stop when ctx is cancelled.
func (s *Service) startRiverQueue(ctx context.Context) error {
	driver := riverpgxv5.New(s.store.Pool())
	migrator, err := rivermigrate.New(driver, nil)
	if err != nil {
		return err
	}
	// River manages its own job tables; migrate up before starting.
	if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil {
		return err
	}

	workers := river.NewWorkers()
	if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil {
		return err
	}
	riverClient, err := river.NewClient(driver, &river.Config{
		ID:                          asyncWorkerID(),
		JobTimeout:                  -1, // no per-job timeout; task runs manage their own deadlines
		Logger:                      s.logger,
		CompletedJobRetentionPeriod: 24 * time.Hour,
		Queues: map[string]river.QueueConfig{
			asyncTaskQueueName: {MaxWorkers: 32},
		},
		RescueStuckJobsAfter: 30 * time.Second,
		TestOnly:             s.cfg.AppEnv == "test",
		Workers:              workers,
	})
	if err != nil {
		return err
	}
	// Store the client before Start so EnqueueAsyncTask can use it.
	s.riverClient = riverClient
	if err := riverClient.Start(ctx); err != nil {
		return err
	}
	// Re-insert jobs for tasks persisted before a restart so none are lost.
	if err := s.recoverAsyncRiverJobs(ctx); err != nil {
		return err
	}
	go func() {
		<-ctx.Done()
		// ctx is already cancelled here, so stop with a fresh bounded context.
		stopCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := riverClient.StopAndCancel(stopCtx); err != nil {
			s.logger.Warn("stop river async queue failed", "error", err)
		}
	}()
	return nil
}
|
||||
|
||||
func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error {
|
||||
if s.riverClient == nil {
|
||||
return errors.New("river async queue is not started")
|
||||
}
|
||||
result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.Job != nil {
|
||||
return s.store.SetTaskRiverJobID(ctx, task.ID, result.Job.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WakeAsyncQueueAfter is intentionally a no-op under the river-backed queue:
// river's own scheduler wakes snoozed and scheduled jobs, so there is nothing
// to signal here. NOTE(review): presumably retained for compatibility with
// callers of the previous queue implementation — confirm remaining call sites
// and remove once none depend on it.
func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) {
}
|
||||
|
||||
// RunAsyncTask enqueues the task for background execution. The user argument
// is unused here: the worker reconstructs the principal from the task row
// (see authUserFromTask). Enqueue failures are only logged, so callers that
// must surface errors should call EnqueueAsyncTask directly.
func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) {
	if err := s.EnqueueAsyncTask(ctx, task); err != nil {
		s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err)
	}
}
|
||||
|
||||
func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error {
|
||||
items, err := s.store.ListRecoverableAsyncTasks(ctx, 1000)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range items {
|
||||
task := store.GatewayTask{ID: item.ID}
|
||||
result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: item.ID}, asyncTaskInsertOpts(task))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.Job != nil {
|
||||
if err := s.store.SetTaskRiverJobID(ctx, item.ID, result.Job.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(items) > 0 {
|
||||
s.logger.Info("river async queue recovered persisted tasks", "count", len(items))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts {
|
||||
priority := 2
|
||||
if task.ID == "" {
|
||||
priority = 3
|
||||
}
|
||||
return &river.InsertOpts{
|
||||
MaxAttempts: 1000,
|
||||
Priority: priority,
|
||||
Queue: asyncTaskQueueName,
|
||||
Tags: []string{"gateway-task"},
|
||||
UniqueOpts: river.UniqueOpts{
|
||||
ByArgs: true,
|
||||
ByQueue: true,
|
||||
ByState: []rivertype.JobState{
|
||||
rivertype.JobStateAvailable,
|
||||
rivertype.JobStatePending,
|
||||
rivertype.JobStateRetryable,
|
||||
rivertype.JobStateRunning,
|
||||
rivertype.JobStateScheduled,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func authUserFromTask(task store.GatewayTask) *auth.User {
|
||||
roles := []string{"user"}
|
||||
if strings.TrimSpace(task.UserID) == "" {
|
||||
roles = nil
|
||||
}
|
||||
return &auth.User{
|
||||
ID: firstNonEmptyString(task.GatewayUserID, task.UserID),
|
||||
Roles: roles,
|
||||
TenantID: task.TenantID,
|
||||
GatewayTenantID: task.GatewayTenantID,
|
||||
TenantKey: task.TenantKey,
|
||||
Source: firstNonEmptyString(task.UserSource, "gateway"),
|
||||
GatewayUserID: task.GatewayUserID,
|
||||
UserGroupID: task.UserGroupID,
|
||||
UserGroupKey: task.UserGroupKey,
|
||||
APIKeyID: task.APIKeyID,
|
||||
APIKeyName: task.APIKeyName,
|
||||
APIKeyPrefix: task.APIKeyPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
func asyncWorkerID() string {
|
||||
host, _ := os.Hostname()
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
return fmt.Sprintf("%s:%d:%d", host, os.Getpid(), time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// Compile-time check that asyncTaskWorker implements river.Worker.
var _ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil)
|
||||
@ -1,6 +1,7 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -54,6 +55,9 @@ func shouldRetrySameClient(candidate store.RuntimeModelCandidate, err error) boo
|
||||
func retryDecisionForCandidate(candidate store.RuntimeModelCandidate, err error) retryDecision {
|
||||
policy := effectiveRetryPolicy(candidate)
|
||||
info := failureInfoFromError(err)
|
||||
if errors.Is(err, store.ErrRateLimited) {
|
||||
return retryDecision{Retry: false, Reason: "local_rate_limit_wait_queue", Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info}
|
||||
}
|
||||
if !boolFromPolicy(policy, "enabled", true) {
|
||||
return retryDecision{Retry: false, Reason: "retry_disabled", Match: policyRuleMatch{Source: "model_runtime_policy_sets.retry_policy", Policy: "retryPolicy", Rule: "enabled", Value: "false"}, Info: info}
|
||||
}
|
||||
@ -94,6 +98,9 @@ func failoverDecisionForCandidate(runnerPolicy store.RunnerPolicy, candidate sto
|
||||
if cooldownSeconds <= 0 {
|
||||
cooldownSeconds = 300
|
||||
}
|
||||
if errors.Is(err, store.ErrRateLimited) && store.RateLimitRetryable(err) {
|
||||
return failoverDecision{Retry: true, Action: "next", Reason: "local_rate_limit_try_next_candidate", CooldownSeconds: cooldownSeconds, Match: policyRuleMatch{Source: "gateway_rate_limits", Policy: "rateLimitPolicy", Rule: "localCapacity", Value: "exceeded"}, Info: info}
|
||||
}
|
||||
if match, ok := failoverAllowMatchWithSources(runnerPolicy.FailoverPolicy, overridePolicy, info); ok {
|
||||
return failoverDecision{Retry: true, Action: action, Reason: "failover_allow_policy", CooldownSeconds: cooldownSeconds, Match: match, Info: info}
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
@ -65,6 +66,9 @@ func (s *Service) applyFailoverAction(ctx context.Context, taskID string, candid
|
||||
}
|
||||
|
||||
func (s *Service) applyPriorityDemotePolicy(ctx context.Context, taskID string, attemptNo int, runnerPolicy store.RunnerPolicy, candidate store.RuntimeModelCandidate, cause error, simulated bool) {
|
||||
if errors.Is(cause, store.ErrRateLimited) {
|
||||
return
|
||||
}
|
||||
decision := priorityDemoteDecisionForCandidate(runnerPolicy, cause)
|
||||
if !decision.Demote {
|
||||
return
|
||||
|
||||
@ -3,6 +3,7 @@ package runner
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -12,6 +13,8 @@ import (
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/riverqueue/river"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@ -20,6 +23,7 @@ type Service struct {
|
||||
logger *slog.Logger
|
||||
clients map[string]clients.Client
|
||||
httpClients *httpClientCache
|
||||
riverClient *river.Client[pgx.Tx]
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
@ -27,6 +31,20 @@ type Result struct {
|
||||
Output map[string]any
|
||||
}
|
||||
|
||||
var ErrTaskQueued = errors.New("task queued")
|
||||
|
||||
type TaskQueuedError struct {
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func (e *TaskQueuedError) Error() string {
|
||||
return ErrTaskQueued.Error()
|
||||
}
|
||||
|
||||
func (e *TaskQueuedError) Is(target error) bool {
|
||||
return target == ErrTaskQueued
|
||||
}
|
||||
|
||||
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
|
||||
httpClients := newHTTPClientCache()
|
||||
return &Service{
|
||||
@ -55,6 +73,14 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
executeStartedAt := time.Now()
|
||||
body := normalizeRequest(task.Kind, task.Request)
|
||||
modelType := modelTypeFromKind(task.Kind, body)
|
||||
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
if task.Status != "running" {
|
||||
if err := s.emit(ctx, task.ID, "task.running", "running", "starting", 0.12, "task pulled from queue and started", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
}
|
||||
if err := validateRequest(task.Kind, body); err != nil {
|
||||
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
|
||||
if finishErr != nil {
|
||||
@ -83,9 +109,6 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
return Result{}, err
|
||||
}
|
||||
}
|
||||
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
@ -96,7 +119,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
}
|
||||
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
|
||||
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
|
||||
attemptNo := 0
|
||||
attemptNo := task.AttemptCount
|
||||
var lastErr error
|
||||
for index, candidate := range candidates {
|
||||
if index >= maxPlatforms {
|
||||
@ -251,6 +274,20 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
if lastErr != nil {
|
||||
message = lastErr.Error()
|
||||
}
|
||||
if task.AsyncMode && ctx.Err() != nil {
|
||||
queued, queueErr := s.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
|
||||
if queueErr != nil {
|
||||
return Result{}, queueErr
|
||||
}
|
||||
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0}
|
||||
}
|
||||
if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) {
|
||||
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr)
|
||||
if queueErr != nil {
|
||||
return Result{}, queueErr
|
||||
}
|
||||
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
|
||||
}
|
||||
failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
@ -261,7 +298,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
|
||||
simulated := isSimulation(task, candidate)
|
||||
if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("emit attempt started: %w", err)
|
||||
}
|
||||
attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
|
||||
TaskID: task.ID,
|
||||
@ -276,21 +313,22 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
Metrics: attemptMetrics(candidate, attemptNo, simulated),
|
||||
})
|
||||
if err != nil {
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
|
||||
}
|
||||
reservations := s.rateLimitReservations(ctx, user, candidate, body)
|
||||
limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations)
|
||||
if err != nil {
|
||||
clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false}
|
||||
retryable := store.RateLimitRetryable(err)
|
||||
clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: retryable}
|
||||
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||
AttemptID: attemptID,
|
||||
Status: "failed",
|
||||
Retryable: false,
|
||||
Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}),
|
||||
Retryable: retryable,
|
||||
Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": retryable, "retryAfterMs": localRateLimitRetryAfter(err).Milliseconds(), "trace": []any{failureTraceEntry(clientErr, retryable)}}),
|
||||
ErrorCode: "rate_limit",
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
return clients.Response{}, clientErr
|
||||
return clients.Response{}, &localRateLimitError{clientErr: clientErr, cause: err, retryAfter: localRateLimitRetryAfter(err)}
|
||||
}
|
||||
rateReservationsFinalized := false
|
||||
defer func() {
|
||||
@ -301,7 +339,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)
|
||||
|
||||
if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("record client assignment: %w", err)
|
||||
}
|
||||
defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")
|
||||
|
||||
@ -315,17 +353,25 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
ErrorCode: clients.ErrorCode(err),
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
|
||||
}
|
||||
client := s.clientFor(candidate, simulated)
|
||||
callStartedAt := time.Now()
|
||||
response, err := client.Run(ctx, clients.Request{
|
||||
Kind: task.Kind,
|
||||
ModelType: candidate.ModelType,
|
||||
Model: task.Model,
|
||||
Body: body,
|
||||
Candidate: candidate,
|
||||
HTTPClient: requestHTTPClient,
|
||||
Kind: task.Kind,
|
||||
ModelType: candidate.ModelType,
|
||||
Model: task.Model,
|
||||
Body: body,
|
||||
Candidate: candidate,
|
||||
HTTPClient: requestHTTPClient,
|
||||
RemoteTaskID: task.RemoteTaskID,
|
||||
RemoteTaskPayload: task.RemoteTaskPayload,
|
||||
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
|
||||
if strings.TrimSpace(remoteTaskID) == "" {
|
||||
return nil
|
||||
}
|
||||
return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
|
||||
},
|
||||
Stream: boolFromMap(body, "stream"),
|
||||
StreamDelta: onDelta,
|
||||
})
|
||||
@ -400,11 +446,11 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
response.Result = uploadedResult
|
||||
for _, progress := range response.Progress {
|
||||
if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("emit task progress: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil {
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("commit rate limit reservations: %w", err)
|
||||
}
|
||||
rateReservationsFinalized = true
|
||||
if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||
@ -418,7 +464,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
ResponseFinishedAt: response.ResponseFinishedAt,
|
||||
ResponseDurationMS: response.ResponseDurationMS,
|
||||
}); err != nil {
|
||||
return clients.Response{}, err
|
||||
return clients.Response{}, fmt.Errorf("finish task attempt: %w", err)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
@ -459,6 +505,41 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
|
||||
return failed, nil
|
||||
}
|
||||
|
||||
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) {
|
||||
delay := localRateLimitRetryAfter(cause)
|
||||
if delay <= 0 {
|
||||
delay = 5 * time.Second
|
||||
}
|
||||
queued, err := s.store.RequeueTask(ctx, task.ID, delay)
|
||||
if err != nil {
|
||||
return store.GatewayTask{}, 0, err
|
||||
}
|
||||
payload := map[string]any{
|
||||
"code": "rate_limit",
|
||||
"message": cause.Error(),
|
||||
"retryAfterMs": delay.Milliseconds(),
|
||||
}
|
||||
if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "rate_limited", 0.2, "task queued by local rate limit", payload, task.RunMode == "simulation"); eventErr != nil {
|
||||
return store.GatewayTask{}, 0, eventErr
|
||||
}
|
||||
return queued, delay, nil
|
||||
}
|
||||
|
||||
func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) {
|
||||
queued, err := s.store.RequeueTask(ctx, task.ID, 0)
|
||||
if err != nil {
|
||||
return store.GatewayTask{}, err
|
||||
}
|
||||
payload := map[string]any{"code": "worker_interrupted"}
|
||||
if task.RemoteTaskID != "" {
|
||||
payload["remoteTaskId"] = task.RemoteTaskID
|
||||
}
|
||||
if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "worker_interrupted", 0.2, "async task queued after worker interruption", payload, task.RunMode == "simulation"); eventErr != nil {
|
||||
return store.GatewayTask{}, eventErr
|
||||
}
|
||||
return queued, nil
|
||||
}
|
||||
|
||||
func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any {
|
||||
attempts, err := s.store.ListTaskAttempts(ctx, taskID)
|
||||
if err != nil {
|
||||
|
||||
@ -19,7 +19,7 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
|
||||
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
|
||||
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
|
||||
COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''),
|
||||
$2 AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
|
||||
$2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
|
||||
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
|
||||
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
|
||||
COALESCE(b.pricing_rule_set_id::text, ''),
|
||||
@ -33,21 +33,21 @@ LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provi
|
||||
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
|
||||
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
|
||||
LEFT JOIN runtime_client_states s
|
||||
ON s.client_id = p.platform_key || ':' || $2 || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
|
||||
ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
|
||||
WHERE p.status = 'enabled'
|
||||
AND p.deleted_at IS NULL
|
||||
AND m.enabled = true
|
||||
AND m.model_type @> jsonb_build_array($2)
|
||||
AND m.model_type @> jsonb_build_array($2::text)
|
||||
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
|
||||
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
|
||||
AND (
|
||||
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1)
|
||||
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
|
||||
OR (
|
||||
COALESCE(m.model_alias, '') = ''
|
||||
AND (
|
||||
m.model_name = $1
|
||||
OR b.canonical_model_key = $1
|
||||
OR b.provider_model_name = $1
|
||||
m.model_name = $1::text
|
||||
OR b.canonical_model_key = $1::text
|
||||
OR b.provider_model_name = $1::text
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -184,15 +184,15 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
|
||||
WHERE p.status = 'enabled'
|
||||
AND p.deleted_at IS NULL
|
||||
AND m.enabled = true
|
||||
AND m.model_type @> jsonb_build_array($2)
|
||||
AND m.model_type @> jsonb_build_array($2::text)
|
||||
AND (
|
||||
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1)
|
||||
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
|
||||
OR (
|
||||
COALESCE(m.model_alias, '') = ''
|
||||
AND (
|
||||
m.model_name = $1
|
||||
OR b.canonical_model_key = $1
|
||||
OR b.provider_model_name = $1
|
||||
m.model_name = $1::text
|
||||
OR b.canonical_model_key = $1::text
|
||||
OR b.provider_model_name = $1::text
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@ -53,6 +53,10 @@ func (s *Store) Ping(ctx context.Context) error {
|
||||
return s.pool.Ping(ctx)
|
||||
}
|
||||
|
||||
// Pool returns the store's underlying pgx connection pool so other components
// can share the same database connections instead of opening their own.
func (s *Store) Pool() *pgxpool.Pool {
	return s.pool
}
|
||||
|
||||
type Platform struct {
|
||||
ID string `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
@ -374,6 +378,7 @@ type CreateTaskInput struct {
|
||||
Kind string `json:"kind"`
|
||||
Model string `json:"model"`
|
||||
RunMode string `json:"runMode"`
|
||||
Async bool `json:"async"`
|
||||
Request map[string]any `json:"request"`
|
||||
}
|
||||
|
||||
@ -398,7 +403,12 @@ type GatewayTask struct {
|
||||
ResolvedModel string `json:"resolvedModel,omitempty"`
|
||||
RequestID string `json:"requestId,omitempty"`
|
||||
Request map[string]any `json:"request,omitempty"`
|
||||
AsyncMode bool `json:"asyncMode"`
|
||||
RiverJobID int64 `json:"riverJobId,omitempty"`
|
||||
Status string `json:"status"`
|
||||
AttemptCount int `json:"attemptCount"`
|
||||
RemoteTaskID string `json:"remoteTaskId,omitempty"`
|
||||
RemoteTaskPayload map[string]any `json:"remoteTaskPayload,omitempty"`
|
||||
Result map[string]any `json:"result,omitempty"`
|
||||
Billings []any `json:"billings,omitempty"`
|
||||
Usage map[string]any `json:"usage"`
|
||||
@ -423,7 +433,9 @@ COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_
|
||||
COALESCE(api_key_id, ''), COALESCE(api_key_name, ''), COALESCE(api_key_prefix, ''),
|
||||
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model,
|
||||
COALESCE(model_type, ''), COALESCE(requested_model, ''), COALESCE(resolved_model, ''), COALESCE(request_id, ''),
|
||||
request, status, COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
|
||||
request, COALESCE(async_mode, false), COALESCE(river_job_id, 0), status, COALESCE(attempt_count, 0),
|
||||
COALESCE(remote_task_id, ''), COALESCE(remote_task_payload, '{}'::jsonb),
|
||||
COALESCE(result, '{}'::jsonb), COALESCE(billings, '[]'::jsonb),
|
||||
COALESCE(usage, '{}'::jsonb), COALESCE(metrics, '{}'::jsonb), COALESCE(billing_summary, '{}'::jsonb),
|
||||
COALESCE(final_charge_amount, 0)::float8, COALESCE(response_started_at::text, ''),
|
||||
COALESCE(response_finished_at::text, ''), COALESCE(response_duration_ms, 0), COALESCE(error, ''),
|
||||
@ -1675,11 +1687,11 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
|
||||
INSERT INTO gateway_tasks (
|
||||
kind, run_mode, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
|
||||
api_key_id, api_key_name, api_key_prefix, user_group_id, user_group_key,
|
||||
model, requested_model, request, status, result, billings, finished_at
|
||||
model, requested_model, request, async_mode, status, result, billings, finished_at
|
||||
)
|
||||
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17::jsonb, $18::jsonb, CASE WHEN $19 THEN now() ELSE NULL END)
|
||||
VALUES ($1, $2, $3, NULLIF($4, '')::uuid, COALESCE(NULLIF($5, ''), 'gateway'), NULLIF($6, '')::uuid, NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, ''), NULLIF($10, ''), NULLIF($11, ''), NULLIF($12, '')::uuid, NULLIF($13, ''), $14, $14, $15, $16, $17, $18::jsonb, $19::jsonb, CASE WHEN $20 THEN now() ELSE NULL END)
|
||||
RETURNING `+gatewayTaskColumns,
|
||||
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, status, resultBody, billingsBody, false,
|
||||
input.Kind, runMode, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.APIKeyName, user.APIKeyPrefix, user.UserGroupID, user.UserGroupKey, input.Model, requestBody, input.Async, status, resultBody, billingsBody, false,
|
||||
))
|
||||
if err != nil {
|
||||
return GatewayTask{}, err
|
||||
@ -1689,7 +1701,7 @@ func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *aut
|
||||
payload, _ := json.Marshal(event.Payload)
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
|
||||
VALUES ($1::uuid, $2, $3, NULLIF($4, ''), NULLIF($5, ''), $6, NULLIF($7, ''), $8::jsonb, $9)`,
|
||||
VALUES ($1::uuid, $2, $3::text, NULLIF($4::text, ''), NULLIF($5::text, ''), $6, NULLIF($7::text, ''), $8::jsonb, $9)`,
|
||||
task.ID, event.Seq, event.EventType, event.Status, event.Phase, event.Progress, event.Message, string(payload), event.Simulated,
|
||||
); err != nil {
|
||||
return GatewayTask{}, err
|
||||
@ -1730,6 +1742,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
|
||||
var usageBytes []byte
|
||||
var metricsBytes []byte
|
||||
var billingSummaryBytes []byte
|
||||
var remoteTaskPayloadBytes []byte
|
||||
if err := scanner.Scan(
|
||||
&task.ID,
|
||||
&task.Kind,
|
||||
@ -1751,7 +1764,12 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
|
||||
&task.ResolvedModel,
|
||||
&task.RequestID,
|
||||
&requestBytes,
|
||||
&task.AsyncMode,
|
||||
&task.RiverJobID,
|
||||
&task.Status,
|
||||
&task.AttemptCount,
|
||||
&task.RemoteTaskID,
|
||||
&remoteTaskPayloadBytes,
|
||||
&resultBytes,
|
||||
&billingsBytes,
|
||||
&usageBytes,
|
||||
@ -1771,6 +1789,7 @@ func scanGatewayTask(scanner taskScanner) (GatewayTask, error) {
|
||||
return GatewayTask{}, err
|
||||
}
|
||||
task.Request = decodeObject(requestBytes)
|
||||
task.RemoteTaskPayload = decodeObject(remoteTaskPayloadBytes)
|
||||
task.Result = decodeObject(resultBytes)
|
||||
task.Billings = decodeArray(billingsBytes)
|
||||
task.Usage = decodeObject(usageBytes)
|
||||
|
||||
@ -31,6 +31,7 @@ type ModelRateLimitStatus struct {
|
||||
PlatformCooldownUntil string `json:"platformCooldownUntil,omitempty"`
|
||||
ModelCooldownUntil string `json:"modelCooldownUntil,omitempty"`
|
||||
Concurrent RateLimitMetricStatus `json:"concurrent"`
|
||||
QueuedTasks float64 `json:"queuedTasks"`
|
||||
RPM RateLimitMetricStatus `json:"rpm"`
|
||||
TPM RateLimitMetricStatus `json:"tpm"`
|
||||
LoadRatio float64 `json:"loadRatio"`
|
||||
@ -38,15 +39,16 @@ type ModelRateLimitStatus struct {
|
||||
|
||||
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT m.id::text, m.platform_id::text, p.name, p.provider,
|
||||
m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''),
|
||||
m.model_type, m.display_name, m.enabled,
|
||||
p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy,
|
||||
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
|
||||
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
|
||||
COALESCE(con.active, 0)::float8,
|
||||
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''),
|
||||
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '')
|
||||
SELECT m.id::text, m.platform_id::text, p.name, p.provider,
|
||||
m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''),
|
||||
m.model_type, m.display_name, m.enabled,
|
||||
p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy,
|
||||
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
|
||||
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
|
||||
COALESCE(con.active, 0)::float8,
|
||||
COALESCE(queued.waiting, 0)::float8,
|
||||
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''),
|
||||
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '')
|
||||
FROM platform_models m
|
||||
JOIN integration_platforms p ON p.id = m.platform_id
|
||||
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
|
||||
@ -59,6 +61,19 @@ LEFT JOIN (
|
||||
AND expires_at > now()
|
||||
GROUP BY scope_key
|
||||
) con ON con.scope_key = m.id::text
|
||||
LEFT JOIN (
|
||||
SELECT latest.platform_model_id, COUNT(*) AS waiting
|
||||
FROM (
|
||||
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
|
||||
FROM gateway_tasks t
|
||||
JOIN gateway_task_attempts a ON a.task_id = t.id
|
||||
WHERE t.async_mode = true
|
||||
AND t.status = 'queued'
|
||||
AND a.platform_model_id IS NOT NULL
|
||||
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
|
||||
) latest
|
||||
GROUP BY latest.platform_model_id
|
||||
) queued ON queued.platform_model_id = m.id::text
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at
|
||||
FROM gateway_rate_limit_counters
|
||||
@ -93,6 +108,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
|
||||
var platformCooldownUntil string
|
||||
var modelCooldownUntil string
|
||||
var concurrentCurrent float64
|
||||
var queuedTasks float64
|
||||
var rpmUsed float64
|
||||
var rpmReserved float64
|
||||
var rpmResetAt string
|
||||
@ -117,6 +133,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
|
||||
&platformCooldownUntil,
|
||||
&modelCooldownUntil,
|
||||
&concurrentCurrent,
|
||||
&queuedTasks,
|
||||
&rpmUsed,
|
||||
&rpmReserved,
|
||||
&rpmResetAt,
|
||||
@ -136,6 +153,7 @@ ORDER BY p.priority ASC, m.model_name ASC`)
|
||||
item.PlatformCooldownUntil = platformCooldownUntil
|
||||
item.ModelCooldownUntil = modelCooldownUntil
|
||||
item.RateLimitPolicy = policy
|
||||
item.QueuedTasks = queuedTasks
|
||||
item.Concurrent = metricStatus(concurrentCurrent, concurrentCurrent, 0, rateLimitForMetric(policy, "concurrent"), "")
|
||||
item.RPM = metricStatus(rpmUsed+rpmReserved, rpmUsed, rpmReserved, rateLimitForMetric(policy, "rpm"), rpmResetAt)
|
||||
item.TPM = metricStatus(tpmUsed+tpmReserved, tpmUsed, tpmReserved, tpmLimit(policy), tpmResetAt)
|
||||
|
||||
@ -3,6 +3,7 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
@ -13,6 +14,7 @@ type RuntimeRecoveryResult struct {
|
||||
ReleasedRateReservations int64 `json:"releasedRateReservations"`
|
||||
FailedAttempts int64 `json:"failedAttempts"`
|
||||
FailedTasks int64 `json:"failedTasks"`
|
||||
RequeuedAsyncTasks int64 `json:"requeuedAsyncTasks"`
|
||||
}
|
||||
|
||||
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) {
|
||||
@ -28,7 +30,11 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID
|
||||
continue
|
||||
}
|
||||
if reservation.Metric == "" || reservation.Amount > reservation.Limit {
|
||||
return RateLimitResult{}, ErrRateLimited
|
||||
return RateLimitResult{}, &RateLimitExceededError{
|
||||
Metric: reservation.Metric,
|
||||
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
|
||||
Retryable: false,
|
||||
}
|
||||
}
|
||||
if reservation.WindowSeconds <= 0 {
|
||||
reservation.WindowSeconds = 60
|
||||
@ -55,8 +61,10 @@ func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, atte
|
||||
reservation.LeaseTTLSeconds = 120
|
||||
}
|
||||
var active float64
|
||||
var nextAvailableAt time.Time
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT COALESCE(SUM(lease_value), 0)::float8
|
||||
SELECT COALESCE(SUM(lease_value), 0)::float8,
|
||||
COALESCE(MIN(expires_at), now() + ($3::int * interval '1 second'))
|
||||
FROM gateway_concurrency_leases
|
||||
WHERE scope_type = $1
|
||||
AND scope_key = $2
|
||||
@ -64,11 +72,17 @@ WHERE scope_type = $1
|
||||
AND expires_at > now()`,
|
||||
reservation.ScopeType,
|
||||
reservation.ScopeKey,
|
||||
).Scan(&active); err != nil {
|
||||
reservation.LeaseTTLSeconds,
|
||||
).Scan(&active, &nextAvailableAt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if active+reservation.Amount > reservation.Limit {
|
||||
return "", ErrRateLimited
|
||||
return "", &RateLimitExceededError{
|
||||
Metric: reservation.Metric,
|
||||
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
|
||||
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
|
||||
Retryable: true,
|
||||
}
|
||||
}
|
||||
var leaseID string
|
||||
if err := tx.QueryRow(ctx, `
|
||||
@ -92,13 +106,16 @@ func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attempt
|
||||
reservedAmount := reservation.Amount
|
||||
var windowStart time.Time
|
||||
err := tx.QueryRow(ctx, `
|
||||
WITH bounds AS (
|
||||
SELECT
|
||||
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) AS window_start,
|
||||
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) + ($7::int * interval '1 second') AS reset_at
|
||||
)
|
||||
INSERT INTO gateway_rate_limit_counters (
|
||||
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
|
||||
)
|
||||
VALUES (
|
||||
$1, $2, $3, date_trunc('minute', now()), $4, $5, $6,
|
||||
date_trunc('minute', now()) + ($7::int * interval '1 second')
|
||||
)
|
||||
SELECT $1, $2, $3, bounds.window_start, $4, $5, $6, bounds.reset_at
|
||||
FROM bounds
|
||||
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
|
||||
SET limit_value = EXCLUDED.limit_value,
|
||||
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
|
||||
@ -117,7 +134,28 @@ RETURNING window_start`,
|
||||
).Scan(&windowStart)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return RateLimitReservation{}, ErrRateLimited
|
||||
resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
|
||||
_ = tx.QueryRow(ctx, `
|
||||
WITH bounds AS (
|
||||
SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
|
||||
)
|
||||
SELECT counters.reset_at
|
||||
FROM gateway_rate_limit_counters counters
|
||||
JOIN bounds ON counters.window_start = bounds.window_start
|
||||
WHERE scope_type = $1
|
||||
AND scope_key = $2
|
||||
AND metric = $3`,
|
||||
reservation.ScopeType,
|
||||
reservation.ScopeKey,
|
||||
reservation.Metric,
|
||||
reservation.WindowSeconds,
|
||||
).Scan(&resetAt)
|
||||
return RateLimitReservation{}, &RateLimitExceededError{
|
||||
Metric: reservation.Metric,
|
||||
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
|
||||
RetryAfter: retryAfterUntil(resetAt),
|
||||
Retryable: true,
|
||||
}
|
||||
}
|
||||
return RateLimitReservation{}, err
|
||||
}
|
||||
@ -144,6 +182,28 @@ RETURNING id::text`,
|
||||
return reservation, nil
|
||||
}
|
||||
|
||||
func retryAfterUntil(when time.Time) time.Duration {
|
||||
if when.IsZero() {
|
||||
return 0
|
||||
}
|
||||
duration := time.Until(when)
|
||||
if duration < time.Second {
|
||||
return time.Second
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func concurrencyRetryAfter(leaseExpiresAt time.Time) time.Duration {
|
||||
if leaseExpiresAt.IsZero() {
|
||||
return time.Second
|
||||
}
|
||||
duration := time.Until(leaseExpiresAt)
|
||||
if duration <= time.Second {
|
||||
return time.Second
|
||||
}
|
||||
return time.Second
|
||||
}
|
||||
|
||||
func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error {
|
||||
return s.finishRateLimitReservations(ctx, reservations, actualByMetric, "committed", "success")
|
||||
}
|
||||
@ -189,23 +249,26 @@ RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`)
|
||||
if err != nil {
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
reservations := make([]RateLimitReservation, 0)
|
||||
for rows.Next() {
|
||||
var reservation RateLimitReservation
|
||||
if err := rows.Scan(&reservation.ScopeType, &reservation.ScopeKey, &reservation.Metric, &reservation.WindowStart, &reservation.Amount); err != nil {
|
||||
rows.Close()
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil {
|
||||
rows.Close()
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
result.ReleasedRateReservations++
|
||||
reservations = append(reservations, reservation)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
rows.Close()
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
rows.Close()
|
||||
for _, reservation := range reservations {
|
||||
if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil {
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
}
|
||||
result.ReleasedRateReservations = int64(len(reservations))
|
||||
|
||||
tag, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_concurrency_leases
|
||||
@ -220,7 +283,7 @@ WHERE released_at IS NULL
|
||||
tag, err = tx.Exec(ctx, `
|
||||
UPDATE gateway_task_attempts
|
||||
SET status = 'failed',
|
||||
retryable = false,
|
||||
retryable = true,
|
||||
error_code = 'server_restarted',
|
||||
error_message = 'attempt interrupted by service restart',
|
||||
finished_at = now()
|
||||
@ -230,6 +293,57 @@ WHERE status = 'running'`)
|
||||
}
|
||||
result.FailedAttempts = tag.RowsAffected()
|
||||
|
||||
asyncTaskRows, err := tx.Query(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'queued',
|
||||
error = NULL,
|
||||
error_code = NULL,
|
||||
error_message = NULL,
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
next_run_at = now(),
|
||||
finished_at = NULL,
|
||||
updated_at = now()
|
||||
WHERE async_mode = true
|
||||
AND status = 'running'
|
||||
RETURNING id::text`)
|
||||
if err != nil {
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
asyncTaskIDs := make([]string, 0)
|
||||
for asyncTaskRows.Next() {
|
||||
var taskID string
|
||||
if err := asyncTaskRows.Scan(&taskID); err != nil {
|
||||
asyncTaskRows.Close()
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
asyncTaskIDs = append(asyncTaskIDs, taskID)
|
||||
}
|
||||
if err := asyncTaskRows.Err(); err != nil {
|
||||
asyncTaskRows.Close()
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
asyncTaskRows.Close()
|
||||
for _, taskID := range asyncTaskIDs {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
|
||||
VALUES (
|
||||
$1::uuid,
|
||||
COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
|
||||
'task.recovered',
|
||||
'queued',
|
||||
'recovered',
|
||||
0.2,
|
||||
'async task recovered after service restart',
|
||||
'{"code":"server_restarted"}'::jsonb,
|
||||
false
|
||||
)`, taskID); err != nil {
|
||||
return RuntimeRecoveryResult{}, err
|
||||
}
|
||||
}
|
||||
result.RequeuedAsyncTasks = int64(len(asyncTaskIDs))
|
||||
|
||||
taskRows, err := tx.Query(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'failed',
|
||||
@ -238,7 +352,8 @@ SET status = 'failed',
|
||||
error_message = 'task interrupted by service restart',
|
||||
finished_at = now(),
|
||||
updated_at = now()
|
||||
WHERE status IN ('queued', 'running')
|
||||
WHERE async_mode = false
|
||||
AND status = 'running'
|
||||
RETURNING id::text`)
|
||||
if err != nil {
|
||||
return RuntimeRecoveryResult{}, err
|
||||
@ -301,9 +416,9 @@ func (s *Store) finishRateLimitReservations(ctx context.Context, reservations []
|
||||
var stored RateLimitReservation
|
||||
err := tx.QueryRow(ctx, `
|
||||
UPDATE gateway_rate_limit_reservations
|
||||
SET status = $2,
|
||||
reason = NULLIF($3, ''),
|
||||
actual_amount = CASE WHEN $2 = 'committed' THEN $4 ELSE actual_amount END,
|
||||
SET status = $2::text,
|
||||
reason = NULLIF($3::text, ''),
|
||||
actual_amount = CASE WHEN $2::text = 'committed' THEN $4 ELSE actual_amount END,
|
||||
finalized_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid
|
||||
|
||||
@ -32,6 +32,43 @@ func ModelCandidateErrorCode(err error) string {
|
||||
return "no_model_candidate"
|
||||
}
|
||||
|
||||
// RateLimitExceededError describes a rejected rate-limit reservation with
// structured retry metadata. It unwraps to ErrRateLimited so existing
// errors.Is(err, ErrRateLimited) checks keep working.
type RateLimitExceededError struct {
	Metric     string        // which limit was hit (e.g. "rpm", "tpm", "concurrent")
	Message    string        // human-readable description; used verbatim by Error when non-blank
	RetryAfter time.Duration // suggested wait before retrying; zero means no hint
	Retryable  bool          // whether the same request may succeed if retried later
}
|
||||
|
||||
func (e *RateLimitExceededError) Error() string {
|
||||
if strings.TrimSpace(e.Message) != "" {
|
||||
return e.Message
|
||||
}
|
||||
if strings.TrimSpace(e.Metric) != "" {
|
||||
return "rate limit exceeded for " + e.Metric
|
||||
}
|
||||
return ErrRateLimited.Error()
|
||||
}
|
||||
|
||||
// Unwrap reports ErrRateLimited as the underlying sentinel so callers using
// errors.Is(err, ErrRateLimited) treat this typed error as a rate limit.
func (e *RateLimitExceededError) Unwrap() error {
	return ErrRateLimited
}
|
||||
|
||||
func RateLimitRetryAfter(err error) time.Duration {
|
||||
var limitErr *RateLimitExceededError
|
||||
if errors.As(err, &limitErr) && limitErr.RetryAfter > 0 {
|
||||
return limitErr.RetryAfter
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// RateLimitRetryable reports whether err represents a retryable rate-limit
// rejection. Typed *RateLimitExceededError values answer from their Retryable
// flag; any other error falls back to a plain ErrRateLimited sentinel match
// (treated as retryable for backward compatibility).
func RateLimitRetryable(err error) bool {
	var limitErr *RateLimitExceededError
	if errors.As(err, &limitErr) {
		return limitErr.Retryable
	}
	return errors.Is(err, ErrRateLimited)
}
|
||||
|
||||
type CreatePlatformModelInput struct {
|
||||
PlatformID string `json:"platformId"`
|
||||
BaseModelID string `json:"baseModelId"`
|
||||
@ -132,6 +169,11 @@ type CreateTaskAttemptInput struct {
|
||||
Metrics map[string]any
|
||||
}
|
||||
|
||||
// AsyncTaskQueueItem is a minimal projection of a gateway task used when
// re-enqueueing recoverable async work (see ListRecoverableAsyncTasks).
type AsyncTaskQueueItem struct {
	ID       string // gateway task id (uuid rendered as text)
	Priority int    // scheduling priority; lower values are claimed first
}
|
||||
|
||||
type FinishTaskAttemptInput struct {
|
||||
AttemptID string
|
||||
Status string
|
||||
|
||||
@ -157,7 +157,7 @@ func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType st
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'running',
|
||||
model_type = NULLIF($2, ''),
|
||||
model_type = NULLIF($2::text, ''),
|
||||
normalized_request = $3::jsonb,
|
||||
locked_at = now(),
|
||||
heartbeat_at = now(),
|
||||
@ -166,6 +166,124 @@ WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
|
||||
return err
|
||||
}
|
||||
|
||||
// ClaimAsyncQueuedTask atomically claims the next due async task: it picks the
// oldest, highest-priority row with status 'queued' whose next_run_at has
// elapsed, flips it to 'running', and stamps the claiming worker plus
// lock/heartbeat timestamps. FOR UPDATE SKIP LOCKED lets concurrent workers
// claim different rows without blocking each other. When no task is eligible
// the query returns no rows and the resulting scan error is surfaced to the
// caller (presumably pgx.ErrNoRows — confirm against callers).
func (s *Store) ClaimAsyncQueuedTask(ctx context.Context, workerID string) (GatewayTask, error) {
	return scanGatewayTask(s.pool.QueryRow(ctx, `
WITH picked AS (
SELECT id AS task_id
FROM gateway_tasks
WHERE async_mode = true
AND status = 'queued'
AND next_run_at <= now()
ORDER BY priority ASC, created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE gateway_tasks t
SET status = 'running',
locked_by = NULLIF($1::text, ''),
locked_at = now(),
heartbeat_at = now(),
updated_at = now()
FROM picked
WHERE t.id = picked.task_id
RETURNING `+gatewayTaskColumns, workerID))
}
|
||||
|
||||
func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) {
|
||||
if delay < time.Second {
|
||||
delay = time.Second
|
||||
}
|
||||
if delay > 10*time.Minute {
|
||||
delay = 10 * time.Minute
|
||||
}
|
||||
nextRunAt := time.Now().Add(delay)
|
||||
return scanGatewayTask(s.pool.QueryRow(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'queued',
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
next_run_at = $2::timestamptz,
|
||||
error = NULL,
|
||||
error_code = NULL,
|
||||
error_message = NULL,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid
|
||||
RETURNING `+gatewayTaskColumns, taskID, nextRunAt))
|
||||
}
|
||||
|
||||
// SetTaskRiverJobID records the River job id associated with a gateway task.
// Non-positive ids are ignored (no-op) so callers can pass the zero value
// when no job was enqueued.
func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error {
	if riverJobID <= 0 {
		return nil
	}
	_, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET river_job_id = $2,
updated_at = now()
WHERE id = $1::uuid`, taskID, riverJobID)
	return err
}
|
||||
|
||||
func (s *Store) SetTaskRemoteTask(ctx context.Context, taskID string, attemptID string, remoteTaskID string, payload map[string]any) error {
|
||||
payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
|
||||
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET remote_task_id = NULLIF($2::text, ''),
|
||||
remote_task_payload = $3::jsonb,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`,
|
||||
taskID,
|
||||
remoteTaskID,
|
||||
string(payloadJSON),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(attemptID) == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_task_attempts
|
||||
SET remote_task_id = NULLIF($2::text, ''),
|
||||
response_snapshot = COALESCE(response_snapshot, '{}'::jsonb) || jsonb_build_object('remote_task_payload', $3::jsonb)
|
||||
WHERE id = $1::uuid`,
|
||||
attemptID,
|
||||
remoteTaskID,
|
||||
string(payloadJSON),
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
|
||||
if limit <= 0 {
|
||||
limit = 500
|
||||
}
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id::text, priority
|
||||
FROM gateway_tasks
|
||||
WHERE async_mode = true
|
||||
AND status IN ('queued', 'running')
|
||||
ORDER BY priority ASC, created_at ASC
|
||||
LIMIT $1`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]AsyncTaskQueueItem, 0)
|
||||
for rows.Next() {
|
||||
var item AsyncTaskQueueItem
|
||||
if err := rows.Scan(&item.ID, &item.Priority); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
|
||||
requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
|
||||
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
||||
@ -182,7 +300,7 @@ INSERT INTO gateway_task_attempts (
|
||||
status, simulated, request_snapshot, metrics
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6,
|
||||
$1::uuid, $2::int, NULLIF($3::text, '')::uuid, NULLIF($4::text, '')::uuid, NULLIF($5::text, ''), $6,
|
||||
$7, $8, $9::jsonb, $10::jsonb
|
||||
)
|
||||
RETURNING id::text`,
|
||||
@ -202,7 +320,7 @@ RETURNING id::text`,
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now()
|
||||
SET attempt_count = GREATEST(attempt_count, $2::int), updated_at = now()
|
||||
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -252,7 +370,7 @@ SET metrics = jsonb_set(
|
||||
true
|
||||
)
|
||||
WHERE task_id = $1::uuid
|
||||
AND attempt_no = $2`, taskID, attemptNo, string(entryJSON))
|
||||
AND attempt_no = $2::int`, taskID, attemptNo, string(entryJSON))
|
||||
return err
|
||||
}
|
||||
|
||||
@ -386,17 +504,17 @@ func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptIn
|
||||
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_task_attempts
|
||||
SET status = $2,
|
||||
SET status = $2::text,
|
||||
retryable = $3,
|
||||
request_id = NULLIF($4, ''),
|
||||
request_id = NULLIF($4::text, ''),
|
||||
usage = $5::jsonb,
|
||||
metrics = $6::jsonb,
|
||||
response_snapshot = $7::jsonb,
|
||||
response_started_at = $8::timestamptz,
|
||||
response_finished_at = $9::timestamptz,
|
||||
response_duration_ms = $10,
|
||||
error_code = NULLIF($11, ''),
|
||||
error_message = NULLIF($12, ''),
|
||||
error_code = NULLIF($11::text, ''),
|
||||
error_message = NULLIF($12::text, ''),
|
||||
finished_at = now()
|
||||
WHERE id = $1::uuid`,
|
||||
input.AttemptID,
|
||||
@ -438,6 +556,9 @@ SET status = 'succeeded',
|
||||
error = NULL,
|
||||
error_code = NULL,
|
||||
error_message = NULL,
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
finished_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`,
|
||||
@ -561,14 +682,17 @@ func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureIn
|
||||
if _, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'failed',
|
||||
error = NULLIF($2, ''),
|
||||
error_code = NULLIF($3, ''),
|
||||
error_message = NULLIF($2, ''),
|
||||
request_id = NULLIF($4, ''),
|
||||
error = NULLIF($2::text, ''),
|
||||
error_code = NULLIF($3::text, ''),
|
||||
error_message = NULLIF($2::text, ''),
|
||||
request_id = NULLIF($4::text, ''),
|
||||
metrics = $5::jsonb,
|
||||
response_started_at = $6::timestamptz,
|
||||
response_finished_at = $7::timestamptz,
|
||||
response_duration_ms = $8,
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
finished_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`,
|
||||
@ -604,7 +728,7 @@ WITH next_seq AS (
|
||||
WHERE task_id = $1::uuid
|
||||
)
|
||||
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
|
||||
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
|
||||
SELECT $1::uuid, next_seq.seq, $2::text, NULLIF($3::text, ''), NULLIF($4::text, ''), $5, NULLIF($6::text, ''), $7::jsonb, $8
|
||||
FROM next_seq
|
||||
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
|
||||
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
|
||||
@ -688,7 +812,7 @@ func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastEr
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE runtime_client_states
|
||||
SET running_count = GREATEST(running_count - 1, 0),
|
||||
last_error = NULLIF($2, ''),
|
||||
last_error = NULLIF($2::text, ''),
|
||||
updated_at = now()
|
||||
WHERE client_id = $1`, clientID, lastError)
|
||||
return err
|
||||
|
||||
6
apps/api/migrations/0031_async_task_queue.sql
Normal file
6
apps/api/migrations/0031_async_task_queue.sql
Normal file
@ -0,0 +1,6 @@
|
||||
ALTER TABLE IF EXISTS gateway_tasks
|
||||
ADD COLUMN IF NOT EXISTS async_mode boolean NOT NULL DEFAULT false;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_queue
|
||||
ON gateway_tasks(async_mode, status, next_run_at, priority, created_at)
|
||||
WHERE async_mode = true;
|
||||
19
apps/api/migrations/0032_river_async_queue.sql
Normal file
19
apps/api/migrations/0032_river_async_queue.sql
Normal file
@ -0,0 +1,19 @@
|
||||
ALTER TABLE IF EXISTS gateway_tasks
|
||||
ADD COLUMN IF NOT EXISTS river_job_id bigint,
|
||||
ADD COLUMN IF NOT EXISTS remote_task_id text,
|
||||
ADD COLUMN IF NOT EXISTS remote_task_payload jsonb;
|
||||
|
||||
UPDATE gateway_tasks
|
||||
SET remote_task_payload = '{}'::jsonb
|
||||
WHERE remote_task_payload IS NULL;
|
||||
|
||||
ALTER TABLE IF EXISTS gateway_tasks
|
||||
ALTER COLUMN remote_task_payload SET DEFAULT '{}'::jsonb;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_river_job
|
||||
ON gateway_tasks(river_job_id)
|
||||
WHERE river_job_id IS NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gateway_tasks_async_recover
|
||||
ON gateway_tasks(async_mode, status, priority, created_at)
|
||||
WHERE async_mode = true AND status IN ('queued', 'running');
|
||||
@ -46,7 +46,6 @@ import {
|
||||
getHealth,
|
||||
getNetworkProxyConfig,
|
||||
getRunnerPolicy,
|
||||
getTask,
|
||||
getWalletSummary,
|
||||
listAccessRules,
|
||||
listAuditLogs,
|
||||
@ -72,6 +71,7 @@ import {
|
||||
listUserGroups,
|
||||
listUsers,
|
||||
loginLocalAccount,
|
||||
pollTaskUntilSettled,
|
||||
registerLocalAccount,
|
||||
replacePlatformModels,
|
||||
setUserWalletBalance,
|
||||
@ -788,7 +788,11 @@ export function App() {
|
||||
setCoreMessage('');
|
||||
try {
|
||||
const response = await runTask(credential, taskForm);
|
||||
const detail = await getTask(credential, response.task.id);
|
||||
const syncTask = (detail: GatewayTask) => {
|
||||
setTaskResult(detail);
|
||||
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
|
||||
};
|
||||
const detail = await pollTaskUntilSettled(credential, response.task, { onUpdate: syncTask });
|
||||
setTaskResult(detail);
|
||||
setTasks((current) => [detail, ...current.filter((item) => item.id !== detail.id)]);
|
||||
invalidateDataKeys('tasks', 'wallet', 'walletTransactions');
|
||||
|
||||
@ -521,6 +521,7 @@ export async function createChatTask(
|
||||
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
|
||||
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/chat/completions', {
|
||||
body: input,
|
||||
headers: { 'X-Async': 'true' },
|
||||
method: 'POST',
|
||||
token,
|
||||
});
|
||||
@ -598,6 +599,7 @@ export async function createImageGenerationTask(
|
||||
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
|
||||
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/generations', {
|
||||
body: input,
|
||||
headers: { 'X-Async': 'true' },
|
||||
method: 'POST',
|
||||
token,
|
||||
});
|
||||
@ -609,6 +611,7 @@ export async function createImageEditTask(
|
||||
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
|
||||
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/edits', {
|
||||
body: input,
|
||||
headers: { 'X-Async': 'true' },
|
||||
method: 'POST',
|
||||
token,
|
||||
});
|
||||
@ -636,6 +639,7 @@ export async function createVideoGenerationTask(
|
||||
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
|
||||
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/videos/generations', {
|
||||
body: input,
|
||||
headers: { 'X-Async': 'true' },
|
||||
method: 'POST',
|
||||
token,
|
||||
});
|
||||
@ -656,6 +660,33 @@ export async function getTask(token: string, taskId: string): Promise<GatewayTas
|
||||
return request<GatewayTask>(`/api/workspace/tasks/${taskId}`, { token });
|
||||
}
|
||||
|
||||
export async function pollTaskUntilSettled(
|
||||
token: string,
|
||||
task: GatewayTask,
|
||||
options: { intervalMs?: number; maxAttempts?: number | null; onUpdate?: (task: GatewayTask) => void } = {},
|
||||
): Promise<GatewayTask> {
|
||||
let detail = task;
|
||||
const intervalMs = options.intervalMs ?? 1200;
|
||||
const maxAttempts = options.maxAttempts ?? Number.POSITIVE_INFINITY;
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt += 1) {
|
||||
if (!taskIsPending(detail.status)) return detail;
|
||||
try {
|
||||
detail = await getTask(token, detail.id);
|
||||
options.onUpdate?.(detail);
|
||||
if (!taskIsPending(detail.status)) return detail;
|
||||
} catch {
|
||||
// Backend restarts or short network gaps should not turn a durable task into a failed UI run.
|
||||
// Only an explicit terminal task status from the task detail endpoint settles the run.
|
||||
}
|
||||
await delay(intervalMs);
|
||||
}
|
||||
return detail;
|
||||
}
|
||||
|
||||
export function taskIsPending(status: string) {
|
||||
return status === 'queued' || status === 'running' || status === 'submitting';
|
||||
}
|
||||
|
||||
export async function listTasks(token: string, query: WorkspaceTaskQuery): Promise<ListResponse<GatewayTask>> {
|
||||
const search = new URLSearchParams({
|
||||
page: String(query.page),
|
||||
@ -707,9 +738,9 @@ export async function getNetworkProxyConfig(token: string): Promise<GatewayNetwo
|
||||
|
||||
async function request<T>(
|
||||
path: string,
|
||||
options: { token?: string; auth?: boolean; method?: string; body?: unknown } = {},
|
||||
options: { token?: string; auth?: boolean; method?: string; body?: unknown; headers?: Record<string, string> } = {},
|
||||
): Promise<T> {
|
||||
const headers: Record<string, string> = {};
|
||||
const headers: Record<string, string> = { ...(options.headers ?? {}) };
|
||||
if (options.auth !== false && options.token) {
|
||||
headers.Authorization = `Bearer ${options.token}`;
|
||||
}
|
||||
@ -731,6 +762,10 @@ async function request<T>(
|
||||
return response.json() as Promise<T>;
|
||||
}
|
||||
|
||||
function delay(ms: number) {
|
||||
return new Promise((resolve) => window.setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
function parseErrorMessage(body: string) {
|
||||
return formatGatewayErrorDetails(parseErrorDetails(body));
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ import { mermaid } from '@streamdown/mermaid';
|
||||
import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts';
|
||||
import { Bot, ChevronDown, Image as ImageIcon, MessageSquarePlus, Paperclip, Send, Sparkles, Video } from 'lucide-react';
|
||||
import { Badge, Button, Select, Textarea } from '../components/ui';
|
||||
import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, getTask, streamChatCompletionText } from '../api';
|
||||
import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, streamChatCompletionText, taskIsPending } from '../api';
|
||||
import type { PlaygroundMode } from '../types';
|
||||
import {
|
||||
defaultMediaGenerationSettings,
|
||||
@ -170,14 +170,12 @@ export function PlaygroundPage(props: {
|
||||
resumableRuns.forEach((run) => {
|
||||
if (!run.task?.id) return;
|
||||
resumedTaskIdsRef.current.add(run.task.id);
|
||||
void pollTaskUntilSettled(credential, run.task)
|
||||
void pollTaskUntilSettled(credential, run.task, {
|
||||
onUpdate: (detail) => updateMediaRunFromTask(run.localId, detail),
|
||||
})
|
||||
.then((detail) => {
|
||||
if (!isMountedRef.current) return;
|
||||
setMediaRuns((current) => updateMediaRun(current, run.localId, {
|
||||
error: gatewayTaskErrorText(detail, '任务执行失败'),
|
||||
status: detail.status,
|
||||
task: detail,
|
||||
}));
|
||||
updateMediaRunFromTask(run.localId, detail);
|
||||
})
|
||||
.catch((err) => {
|
||||
if (!isMountedRef.current) return;
|
||||
@ -249,13 +247,11 @@ export function PlaygroundPage(props: {
|
||||
|
||||
async function pollMediaRunUntilSettled(credential: string, localId: string, task: GatewayTask) {
|
||||
try {
|
||||
const detail = await pollTaskUntilSettled(credential, task);
|
||||
const detail = await pollTaskUntilSettled(credential, task, {
|
||||
onUpdate: (nextTask) => updateMediaRunFromTask(localId, nextTask),
|
||||
});
|
||||
if (!isMountedRef.current) return;
|
||||
setMediaRuns((current) => updateMediaRun(current, localId, {
|
||||
error: gatewayTaskErrorText(detail, '任务执行失败'),
|
||||
status: detail.status,
|
||||
task: detail,
|
||||
}));
|
||||
updateMediaRunFromTask(localId, detail);
|
||||
} catch (err) {
|
||||
if (!isMountedRef.current) return;
|
||||
const errorMessage = err instanceof Error ? err.message : '任务状态同步失败';
|
||||
@ -263,6 +259,15 @@ export function PlaygroundPage(props: {
|
||||
}
|
||||
}
|
||||
|
||||
function updateMediaRunFromTask(localId: string, task: GatewayTask) {
|
||||
if (!isMountedRef.current) return;
|
||||
setMediaRuns((current) => updateMediaRun(current, localId, {
|
||||
error: taskIsPending(task.status) ? '' : gatewayTaskErrorText(task, '任务执行失败'),
|
||||
status: task.status,
|
||||
task,
|
||||
}));
|
||||
}
|
||||
|
||||
function editMediaRun(run: MediaGenerationRun) {
|
||||
setPrompt(run.prompt);
|
||||
setMediaSettings(run.settings);
|
||||
@ -1042,20 +1047,6 @@ function updateMediaRun(runs: MediaGenerationRun[], localId: string, patch: Part
|
||||
return runs.map((run) => run.localId === localId ? { ...run, ...patch } : run);
|
||||
}
|
||||
|
||||
async function pollTaskUntilSettled(token: string, task: GatewayTask) {
|
||||
let detail = task;
|
||||
for (let attempt = 0; attempt < 20; attempt += 1) {
|
||||
detail = await getTask(token, detail.id);
|
||||
if (!taskIsPending(detail.status)) return detail;
|
||||
await delay(1200);
|
||||
}
|
||||
return detail;
|
||||
}
|
||||
|
||||
function taskIsPending(status: string) {
|
||||
return status === 'queued' || status === 'running' || status === 'submitting';
|
||||
}
|
||||
|
||||
function readStoredMediaRuns(): MediaGenerationRun[] {
|
||||
if (typeof window === 'undefined') return [];
|
||||
try {
|
||||
@ -1182,10 +1173,6 @@ function booleanFromUnknown(value: unknown, fallback: boolean) {
|
||||
return fallback;
|
||||
}
|
||||
|
||||
function delay(ms: number) {
|
||||
return new Promise((resolve) => window.setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
function newLocalId() {
|
||||
return typeof crypto !== 'undefined' && 'randomUUID' in crypto
|
||||
? crypto.randomUUID()
|
||||
|
||||
@ -593,11 +593,14 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor
|
||||
<TableRow className="shTableHeader">
|
||||
<TableHead>模型</TableHead>
|
||||
<TableHead>平台</TableHead>
|
||||
<TableHead>并发</TableHead>
|
||||
<TableHead>TPM</TableHead>
|
||||
<TableHead>RPM</TableHead>
|
||||
<TableHead>状态</TableHead>
|
||||
<TableHead>满载率</TableHead>
|
||||
<TableHead className="platformLimitMetricHead platformLimitNumberHead" title="正在执行 / 并发上限 / 排队任务">
|
||||
<span>并发</span>
|
||||
<small>正在执行 / 并发 / 排队</small>
|
||||
</TableHead>
|
||||
<TableHead className="platformLimitNumberHead">TPM</TableHead>
|
||||
<TableHead className="platformLimitNumberHead">RPM</TableHead>
|
||||
<TableHead className="platformLimitStatusHead">状态</TableHead>
|
||||
<TableHead className="platformLimitNumberHead">满载率</TableHead>
|
||||
</TableRow>
|
||||
{props.statuses.map((status) => {
|
||||
const platform = props.platformMap.get(status.platformId);
|
||||
@ -615,12 +618,12 @@ function RateLimitStatusTable(props: { statuses: ModelRateLimitStatus[]; platfor
|
||||
<small>{status.provider}</small>
|
||||
</span>
|
||||
</TableCell>
|
||||
<TableCell>{metricCell(status.concurrent)}</TableCell>
|
||||
<TableCell>{metricCell(status.tpm, true)}</TableCell>
|
||||
<TableCell>{metricCell(status.rpm)}</TableCell>
|
||||
<TableCell>{modelRuntimeStatusCell(status, props.now)}</TableCell>
|
||||
<TableCell>
|
||||
<span className="rateLoadCell">
|
||||
<TableCell className="platformLimitNumberCell">{concurrencyMetricCell(status)}</TableCell>
|
||||
<TableCell className="platformLimitNumberCell">{metricCell(status.tpm, true)}</TableCell>
|
||||
<TableCell className="platformLimitNumberCell">{metricCell(status.rpm)}</TableCell>
|
||||
<TableCell className="platformLimitStatusCell">{modelRuntimeStatusCell(status, props.now)}</TableCell>
|
||||
<TableCell className="platformLimitNumberCell">
|
||||
<span className="rateLoadCell" data-overloaded={status.loadRatio > 0.8 ? 'true' : undefined}>
|
||||
<strong>{formatPercent(status.loadRatio)}</strong>
|
||||
<span className="rateLoadTrack"><i style={{ width: `${Math.min(status.loadRatio * 100, 100)}%` }} /></span>
|
||||
</span>
|
||||
@ -1210,6 +1213,16 @@ function metricCell(metric: ModelRateLimitStatus['rpm'], includeReserved = false
|
||||
);
|
||||
}
|
||||
|
||||
function concurrencyMetricCell(status: ModelRateLimitStatus) {
|
||||
const queuedTasks = status.queuedTasks ?? 0;
|
||||
const limitText = status.concurrent.limited ? formatLimit(status.concurrent.limitValue) : '不限';
|
||||
return (
|
||||
<span className="rateMetricCell" title="正在执行 / 并发上限 / 排队任务">
|
||||
<strong>{formatLimit(status.concurrent.currentValue)} / {limitText} / {formatLimit(queuedTasks)}</strong>
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
function reservedMetricText(metric: ModelRateLimitStatus['rpm']) {
|
||||
return `已结算 ${formatLimit(metric.usedValue)} + 预占 ${formatLimit(metric.reservedValue)}`;
|
||||
}
|
||||
|
||||
@ -1016,27 +1016,65 @@
|
||||
}
|
||||
|
||||
.platformLimitTable .shTableRow {
|
||||
grid-template-columns: clamp(150px, 16vw, 220px) minmax(148px, 1fr) minmax(104px, max-content) minmax(136px, max-content) minmax(122px, max-content) minmax(132px, max-content) minmax(128px, max-content);
|
||||
min-width: 920px;
|
||||
grid-template-columns: minmax(180px, 1.15fr) minmax(160px, 0.95fr) 150px 170px 140px 132px 132px;
|
||||
min-width: 1064px;
|
||||
}
|
||||
|
||||
.platformLimitTable .shTableHead,
|
||||
.platformLimitTable .shTableCell {
|
||||
display: grid;
|
||||
align-content: center;
|
||||
min-height: 68px;
|
||||
padding-right: 10px;
|
||||
padding-left: 10px;
|
||||
}
|
||||
|
||||
.platformLimitTable .shTableHead {
|
||||
min-height: 74px;
|
||||
}
|
||||
|
||||
.platformLimitMetricHead {
|
||||
display: grid;
|
||||
align-content: center;
|
||||
gap: 3px;
|
||||
white-space: normal;
|
||||
}
|
||||
|
||||
.platformLimitMetricHead small {
|
||||
overflow: hidden;
|
||||
color: var(--muted-foreground);
|
||||
font-size: var(--font-size-xs);
|
||||
font-weight: var(--font-weight-medium);
|
||||
line-height: 1.2;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.platformLimitNumberHead,
|
||||
.platformLimitNumberCell {
|
||||
justify-items: start;
|
||||
}
|
||||
|
||||
.platformLimitStatusHead,
|
||||
.platformLimitStatusCell {
|
||||
justify-items: start;
|
||||
}
|
||||
|
||||
.rateMetricCell,
|
||||
.rateLoadCell {
|
||||
display: grid;
|
||||
min-width: 0;
|
||||
gap: 4px;
|
||||
width: 100%;
|
||||
align-content: start;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
.rateMetricCell strong,
|
||||
.rateLoadCell strong {
|
||||
color: var(--text-strong);
|
||||
font-size: var(--font-size-sm);
|
||||
line-height: 1.25;
|
||||
}
|
||||
|
||||
.rateMetricCell small {
|
||||
@ -1050,6 +1088,7 @@
|
||||
.rateLoadTrack {
|
||||
display: block;
|
||||
height: 6px;
|
||||
width: min(112px, 100%);
|
||||
overflow: hidden;
|
||||
border-radius: 999px;
|
||||
background: #eef2f6;
|
||||
@ -1062,6 +1101,18 @@
|
||||
background: #0f766e;
|
||||
}
|
||||
|
||||
.rateLoadCell[data-overloaded="true"] strong {
|
||||
color: var(--destructive);
|
||||
}
|
||||
|
||||
.rateLoadCell[data-overloaded="true"] .rateLoadTrack {
|
||||
background: #fee2e2;
|
||||
}
|
||||
|
||||
.rateLoadCell[data-overloaded="true"] .rateLoadTrack i {
|
||||
background: var(--destructive);
|
||||
}
|
||||
|
||||
.platformModelToolbar {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(220px, 0.6fr) minmax(260px, 1fr);
|
||||
|
||||
28
go.work.sum
28
go.work.sum
@ -1,16 +1,12 @@
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
|
||||
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
@ -775,6 +775,7 @@ export interface ModelRateLimitStatus {
|
||||
platformCooldownUntil?: string;
|
||||
modelCooldownUntil?: string;
|
||||
concurrent: RateLimitMetricStatus;
|
||||
queuedTasks: number;
|
||||
rpm: RateLimitMetricStatus;
|
||||
tpm: RateLimitMetricStatus;
|
||||
loadRatio: number;
|
||||
@ -808,6 +809,7 @@ export interface GatewayTask {
|
||||
resolvedModel?: string;
|
||||
requestId?: string;
|
||||
request?: Record<string, unknown>;
|
||||
asyncMode?: boolean;
|
||||
status: 'queued' | 'running' | 'succeeded' | 'failed' | 'cancelled' | string;
|
||||
result?: Record<string, unknown>;
|
||||
billings?: unknown[];
|
||||
|
||||
Loading…
Reference in New Issue
Block a user