easyai-ai-gateway/apps/api/internal/store/tasks_runtime.go

518 lines
14 KiB
Go

package store
import (
	"context"
	"encoding/json"
	"errors"
	"math"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)
type TaskListFilter struct {
Query string
ModelType string
CreatedFrom *time.Time
CreatedTo *time.Time
Page int
PageSize int
}
// TaskListResult is one page of tasks produced by Store.ListTasks.
type TaskListResult struct {
	// Items holds the tasks on this page, ordered newest (created_at) first.
	Items []GatewayTask
	// Total counts every task matching the filter across all pages; Page and
	// PageSize echo the effective values after ListTasks applied defaults.
	Total, Page, PageSize int
}
// ListTasks returns one page of the caller's tasks, newest first, together
// with the total number of tasks matching filter.
//
// Ownership: when localGatewayUserID(user) yields an ID, tasks are matched on
// gateway_user_id; otherwise they are matched on the trimmed user.ID. When the
// caller carries an API key ID, results are further limited to that key.
// ErrLocalUserRequired is returned when neither owner identifier is available.
//
// filter.Query is matched as a literal, case-insensitive substring against
// several task columns. LIKE metacharacters in it ('%', '_' and the escape
// character '\') are escaped so they match themselves rather than acting as
// wildcards.
//
// Page defaults to 1; PageSize defaults to 50 and is capped at 100.
func (s *Store) ListTasks(ctx context.Context, user *auth.User, filter TaskListFilter) (TaskListResult, error) {
	page := filter.Page
	if page <= 0 {
		page = 1
	}
	pageSize := filter.PageSize
	switch {
	case pageSize <= 0:
		pageSize = 50
	case pageSize > 100:
		pageSize = 100
	}
	offset := (page - 1) * pageSize

	gatewayUserID := localGatewayUserID(user)
	apiKeyID := ""
	userID := ""
	if user != nil {
		apiKeyID = strings.TrimSpace(user.APIKeyID)
		userID = strings.TrimSpace(user.ID)
	}
	if gatewayUserID == "" && userID == "" {
		return TaskListResult{}, ErrLocalUserRequired
	}

	queryPattern := ""
	if query := strings.TrimSpace(filter.Query); query != "" {
		// Backslash is PostgreSQL's default LIKE escape character. Escaping it
		// first, then the two wildcards, turns the search into a plain
		// substring match instead of letting user input inject patterns.
		escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(query)
		queryPattern = "%" + escaped + "%"
	}

	// $1..$7 are shared by the count query and the page query.
	args := []any{
		gatewayUserID,
		userID,
		apiKeyID,
		queryPattern,
		strings.TrimSpace(filter.ModelType),
		nullableTaskListTime(filter.CreatedFrom),
		nullableTaskListTime(filter.CreatedTo),
	}
	whereSQL := `
WHERE (
(
NULLIF($1, '')::uuid IS NOT NULL
AND gateway_user_id = NULLIF($1, '')::uuid
)
OR (
NULLIF($1, '')::uuid IS NULL
AND NULLIF($2, '') IS NOT NULL
AND user_id = $2
)
)
AND (
NULLIF($3, '') IS NULL
OR api_key_id = $3
)
AND (
NULLIF($4, '') IS NULL
OR id::text ILIKE $4
OR COALESCE(request_id, '') ILIKE $4
OR kind ILIKE $4
OR model ILIKE $4
OR COALESCE(requested_model, '') ILIKE $4
OR COALESCE(resolved_model, '') ILIKE $4
OR COALESCE(api_key_id, '') ILIKE $4
OR COALESCE(api_key_name, '') ILIKE $4
OR COALESCE(api_key_prefix, '') ILIKE $4
OR status ILIKE $4
OR COALESCE(model_type, '') ILIKE $4
)
AND (
NULLIF($5, '') IS NULL
OR model_type = $5
)
AND (
$6::timestamptz IS NULL
OR created_at >= $6::timestamptz
)
AND (
$7::timestamptz IS NULL
OR created_at <= $7::timestamptz
)`

	var total int
	if err := s.pool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks `+whereSQL, args...).Scan(&total); err != nil {
		return TaskListResult{}, err
	}

	// LIMIT/OFFSET become $8/$9. Build a fresh slice rather than appending
	// to args and assigning the result to a different variable.
	queryArgs := make([]any, 0, len(args)+2)
	queryArgs = append(queryArgs, args...)
	queryArgs = append(queryArgs, pageSize, offset)

	rows, err := s.pool.Query(ctx, `
SELECT `+gatewayTaskColumns+`
FROM gateway_tasks
`+whereSQL+`
ORDER BY created_at DESC
LIMIT $8 OFFSET $9`, queryArgs...)
	if err != nil {
		return TaskListResult{}, err
	}
	defer rows.Close()

	// Non-nil so an empty page still serializes as [] rather than null.
	items := make([]GatewayTask, 0)
	for rows.Next() {
		task, err := scanGatewayTask(rows)
		if err != nil {
			return TaskListResult{}, err
		}
		items = append(items, task)
	}
	if err := rows.Err(); err != nil {
		return TaskListResult{}, err
	}
	return TaskListResult{
		Items:    items,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}
func nullableTaskListTime(value *time.Time) any {
if value == nil {
return nil
}
return *value
}
// MarkTaskRunning sets task taskID to status 'running', stores modelType
// (NULL when empty) and the normalized request as JSON, and stamps locked_at,
// heartbeat_at and updated_at with the database clock.
//
// normalizedRequest is passed through emptyObjectIfNil before encoding
// (presumably so nil is stored as {} — confirm against that helper). A JSON
// encoding failure, e.g. a NaN float or an unsupported value type in the map,
// is returned before any database round trip instead of surfacing as an opaque
// "invalid input syntax for type json" error. An UPDATE that matches no task
// is not reported as an error.
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
	normalizedJSON, err := json.Marshal(emptyObjectIfNil(normalizedRequest))
	if err != nil {
		return err
	}
	_, err = s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'running',
model_type = NULLIF($2, ''),
normalized_request = $3::jsonb,
locked_at = now(),
heartbeat_at = now(),
updated_at = now()
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
	return err
}
// CreateTaskAttempt inserts a gateway_task_attempts row for input.TaskID and,
// in the same transaction, raises the parent task's attempt_count to at least
// input.AttemptNo. It returns the new attempt's ID.
//
// Empty PlatformID, PlatformModelID and ClientID are stored as NULL. Status is
// chosen via firstNonEmpty(input.Status, "running"), so presumably an empty
// status becomes "running". On any error, including a JSON encoding failure of
// RequestSnapshot or a failed commit, the returned ID is "".
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
	requestJSON, err := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
	if err != nil {
		return "", err
	}
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return "", err
	}
	// Discards partial work on every early return; harmless after a
	// successful Commit.
	defer tx.Rollback(ctx)

	var attemptID string
	err = tx.QueryRow(ctx, `
INSERT INTO gateway_task_attempts (
task_id, attempt_no, platform_id, platform_model_id, client_id, queue_key,
status, simulated, request_snapshot
)
VALUES (
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6,
$7, $8, $9::jsonb
)
RETURNING id::text`,
		input.TaskID,
		input.AttemptNo,
		input.PlatformID,
		input.PlatformModelID,
		input.ClientID,
		input.QueueKey,
		firstNonEmpty(input.Status, "running"),
		input.Simulated,
		string(requestJSON),
	).Scan(&attemptID)
	if err != nil {
		return "", err
	}
	if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now()
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
		return "", err
	}
	// Only hand out the ID once it is durably committed.
	if err := tx.Commit(ctx); err != nil {
		return "", err
	}
	return attemptID, nil
}
// FinishTaskAttempt records the outcome of attempt input.AttemptID: status,
// retryability, request ID, usage/metrics/response snapshots (as JSON),
// response timing and error details, and stamps finished_at with now().
//
// Empty RequestID, ErrorCode and ErrorMessage are stored as NULL, as are zero
// ResponseStartedAt/ResponseFinishedAt values (via nullableTime). Snapshot
// maps that cannot be JSON-encoded (e.g. NaN floats) cause an error before the
// database is touched. An UPDATE that matches no attempt is not an error.
func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error {
	responseJSON, err := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot))
	if err != nil {
		return err
	}
	usageJSON, err := json.Marshal(emptyObjectIfNil(input.Usage))
	if err != nil {
		return err
	}
	metricsJSON, err := json.Marshal(emptyObjectIfNil(input.Metrics))
	if err != nil {
		return err
	}
	_, err = s.pool.Exec(ctx, `
UPDATE gateway_task_attempts
SET status = $2,
retryable = $3,
request_id = NULLIF($4, ''),
usage = $5::jsonb,
metrics = $6::jsonb,
response_snapshot = $7::jsonb,
response_started_at = $8::timestamptz,
response_finished_at = $9::timestamptz,
response_duration_ms = $10,
error_code = NULLIF($11, ''),
error_message = NULLIF($12, ''),
finished_at = now()
WHERE id = $1::uuid`,
		input.AttemptID,
		input.Status,
		input.Retryable,
		input.RequestID,
		string(usageJSON),
		string(metricsJSON),
		string(responseJSON),
		nullableTime(input.ResponseStartedAt),
		nullableTime(input.ResponseFinishedAt),
		input.ResponseDurationMS,
		input.ErrorCode,
		input.ErrorMessage,
	)
	return err
}
// FinishTaskSuccess marks task input.TaskID as 'succeeded' and stores its
// result, billings, usage, metrics, billing summary, final charge amount and
// response timing. It clears error, error_code and error_message, stamps
// finished_at/updated_at with now(), and returns the task reloaded via
// GetTask.
//
// Empty RequestID/ResolvedModel and zero response times are stored as NULL.
// Any JSON encoding failure of the payload fields is returned before the
// database is touched.
//
// NOTE(review): Billings is marshaled as-is (not via emptyObjectIfNil), so a
// nil value is stored as the JSON literal null — confirm readers accept that.
func (s *Store) FinishTaskSuccess(ctx context.Context, input FinishTaskSuccessInput) (GatewayTask, error) {
	resultJSON, err := json.Marshal(emptyObjectIfNil(input.Result))
	if err != nil {
		return GatewayTask{}, err
	}
	billingsJSON, err := json.Marshal(input.Billings)
	if err != nil {
		return GatewayTask{}, err
	}
	usageJSON, err := json.Marshal(emptyObjectIfNil(input.Usage))
	if err != nil {
		return GatewayTask{}, err
	}
	metricsJSON, err := json.Marshal(emptyObjectIfNil(input.Metrics))
	if err != nil {
		return GatewayTask{}, err
	}
	billingSummaryJSON, err := json.Marshal(emptyObjectIfNil(input.BillingSummary))
	if err != nil {
		return GatewayTask{}, err
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'succeeded',
result = $2::jsonb,
billings = $3::jsonb,
request_id = NULLIF($4, ''),
resolved_model = NULLIF($5, ''),
usage = $6::jsonb,
metrics = $7::jsonb,
billing_summary = $8::jsonb,
final_charge_amount = $9,
response_started_at = $10::timestamptz,
response_finished_at = $11::timestamptz,
response_duration_ms = $12,
error = NULL,
error_code = NULL,
error_message = NULL,
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid`,
		input.TaskID,
		string(resultJSON),
		string(billingsJSON),
		input.RequestID,
		input.ResolvedModel,
		string(usageJSON),
		string(metricsJSON),
		string(billingSummaryJSON),
		input.FinalChargeAmount,
		nullableTime(input.ResponseStartedAt),
		nullableTime(input.ResponseFinishedAt),
		input.ResponseDurationMS,
	); err != nil {
		return GatewayTask{}, err
	}
	return s.GetTask(ctx, input.TaskID)
}
// SettleTaskBilling debits task.FinalChargeAmount from the task owner's wallet
// and records a matching 'task_billing' debit transaction, at most once per
// task.
//
// Tasks with a non-positive charge or a blank GatewayUserID are skipped. The
// wallet currency is BillingSummary["currency"] when it is a non-empty string
// other than "mixed", and "resource" otherwise. The wallet account is created
// on demand (ON CONFLICT DO NOTHING).
//
// Idempotency: the transaction row carries billingIdempotencyKey(task.ID).
// The account row is locked with SELECT ... FOR UPDATE *before* looking for
// that key, so under READ COMMITTED a concurrent settlement of the same task
// waits for the first one to finish and then sees its committed transaction
// instead of debiting the balance again. If the final INSERT still fails with
// a unique violation (SQLSTATE 23505), the database transaction is aborted and
// cannot be committed, so it is rolled back (undoing the balance update) and
// the call reports success because the task has already been billed.
func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
	if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" {
		return nil
	}
	currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"]))
	if currency == "" || currency == "mixed" {
		currency = "resource"
	}
	metadata, err := json.Marshal(map[string]any{
		"taskId":         task.ID,
		"kind":           task.Kind,
		"model":          task.Model,
		"resolvedModel":  task.ResolvedModel,
		"billings":       task.Billings,
		"billingSummary": task.BillingSummary,
	})
	if err != nil {
		return err
	}
	idempotencyKey := billingIdempotencyKey(task.ID)

	// Set inside the transaction when the final INSERT proves the task was
	// already billed; checked after BeginFunc has rolled back.
	alreadyBilled := false
	err = pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
		// Ensure the account exists so it can be locked below.
		if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_accounts (
gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
)
VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6)
ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
			task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil {
			return err
		}

		// Lock the account first: settlements for this wallet are serialized
		// from here on, which makes the idempotency check below race-free.
		var accountID string
		var balanceBefore float64
		var gatewayTenantID string
		if err := tx.QueryRow(ctx, `
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = $2
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
			return err
		}

		// (gateway_user_id, currency) identifies exactly this account, so
		// filtering on account_id is equivalent to joining on those columns.
		var exists bool
		if err := tx.QueryRow(ctx, `
SELECT EXISTS (
SELECT 1
FROM gateway_wallet_transactions
WHERE account_id = $1::uuid
AND idempotency_key = $2
)`, accountID, idempotencyKey).Scan(&exists); err != nil {
			return err
		}
		if exists {
			return nil
		}

		amount := roundMoney(task.FinalChargeAmount)
		balanceAfter := roundMoney(balanceBefore - amount)
		if _, err := tx.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET balance = $2,
total_spent = total_spent + $3,
updated_at = now()
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
			return err
		}
		_, insertErr := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
)`,
			accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, idempotencyKey, task.ID, string(metadata))
		var pgErr *pgconn.PgError
		if errors.As(insertErr, &pgErr) && pgErr.Code == "23505" {
			// The failed statement aborted the transaction, so it must not
			// be committed: return the error to make BeginFunc roll back,
			// and translate it to success below.
			alreadyBilled = true
		}
		return insertErr
	})
	if alreadyBilled {
		return nil
	}
	return err
}
// billingIdempotencyKey builds the wallet-transaction idempotency key used
// when billing a task, in the form "task:<taskID>:billing".
func billingIdempotencyKey(taskID string) string {
	return strings.Join([]string{"task", taskID, "billing"}, ":")
}
// roundMoney rounds value to 6 decimal places, rounding halves away from zero
// so the result is symmetric for negative amounts.
//
// math.Round replaces the former int64(value*1e6+0.5): that conversion
// overflows int64 once |value| exceeds roughly 9.2e12 (an out-of-range
// float-to-int conversion is implementation-defined in Go), and adding 0.5
// before truncating can itself round and mis-handle values just below a half.
func roundMoney(value float64) float64 {
	return math.Round(value*1e6) / 1e6
}
// taskBillingString returns value when its dynamic type is exactly string,
// and "" for anything else (including nil and named string types).
func taskBillingString(value any) string {
	switch text := value.(type) {
	case string:
		return text
	default:
		return ""
	}
}
// FinishTaskFailure marks task input.TaskID as 'failed', stores the failure
// details, metrics and response timing, stamps finished_at/updated_at with
// now(), and returns the task reloaded via GetTask.
//
// input.Message is written to both the error and error_message columns; empty
// Message, Code and RequestID are stored as NULL, as are zero response times.
// A JSON encoding failure of Metrics is returned before the database is
// touched.
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
	metricsJSON, err := json.Marshal(emptyObjectIfNil(input.Metrics))
	if err != nil {
		return GatewayTask{}, err
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'failed',
error = NULLIF($2, ''),
error_code = NULLIF($3, ''),
error_message = NULLIF($2, ''),
request_id = NULLIF($4, ''),
metrics = $5::jsonb,
response_started_at = $6::timestamptz,
response_finished_at = $7::timestamptz,
response_duration_ms = $8,
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid`,
		input.TaskID,
		input.Message,
		input.Code,
		input.RequestID,
		string(metricsJSON),
		nullableTime(input.ResponseStartedAt),
		nullableTime(input.ResponseFinishedAt),
		input.ResponseDurationMS,
	); err != nil {
		return GatewayTask{}, err
	}
	return s.GetTask(ctx, input.TaskID)
}
func nullableTime(value time.Time) any {
if value.IsZero() {
return nil
}
return value
}
// AddTaskEvent appends an event to task taskID's event log and returns the
// stored row.
//
// seq is one more than the task's current highest seq (1 for the first
// event). Empty status, phase and message are stored as NULL and read back as
// "". payload goes through emptyObjectIfNil and is stored as JSON; an encoding
// failure is returned before the database is touched.
//
// NOTE(review): MAX(seq)+1 is computed without any lock, so two concurrent
// calls for the same task can pick the same seq. Depending on the table's
// constraints that yields a duplicate seq or a unique-violation error; confirm
// callers serialize events per task.
func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) {
	payloadJSON, err := json.Marshal(emptyObjectIfNil(payload))
	if err != nil {
		return TaskEvent{}, err
	}
	var event TaskEvent
	var payloadBytes []byte
	err = s.pool.QueryRow(ctx, `
WITH next_seq AS (
SELECT COALESCE(MAX(seq), 0) + 1 AS seq
FROM gateway_task_events
WHERE task_id = $1::uuid
)
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
FROM next_seq
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
		taskID,
		eventType,
		status,
		phase,
		progress,
		message,
		string(payloadJSON),
		simulated,
	).Scan(
		&event.ID,
		&event.TaskID,
		&event.Seq,
		&event.EventType,
		&event.Status,
		&event.Phase,
		&event.Progress,
		&event.Message,
		&payloadBytes,
		&event.Simulated,
		&event.CreatedAt,
	)
	if err != nil {
		return TaskEvent{}, err
	}
	event.Payload = decodeObject(payloadBytes)
	return event, nil
}
// QueueTaskCallback enqueues event for delivery to callbackURL by inserting a
// row into gateway_task_callback_outbox. It is a no-op when callbackURL is
// empty. A row already queued for the same (task_id, seq, callback_url) is
// left untouched (ON CONFLICT DO NOTHING), so re-queuing an event is safe.
//
// The JSON body is built from the event's fields; if it cannot be encoded
// (e.g. an unsupported value inside event.Payload) the error is returned and
// nothing is queued.
func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error {
	if callbackURL == "" {
		return nil
	}
	payloadJSON, err := json.Marshal(map[string]any{
		"taskId":    event.TaskID,
		"seq":       event.Seq,
		"eventType": event.EventType,
		"status":    event.Status,
		"phase":     event.Phase,
		"progress":  event.Progress,
		"message":   event.Message,
		"payload":   event.Payload,
		"simulated": event.Simulated,
		"createdAt": event.CreatedAt,
	})
	if err != nil {
		return err
	}
	_, err = s.pool.Exec(ctx, `
INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload)
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb)
ON CONFLICT (task_id, seq, callback_url) DO NOTHING`,
		event.TaskID,
		event.ID,
		event.Seq,
		callbackURL,
		string(payloadJSON),
	)
	return err
}
// RecordClientAssignment bumps running_count for candidate.ClientID in
// runtime_client_states and stamps last_assigned_at. A client without a row
// yet is inserted with running_count = 1.
//
// On conflict only running_count, last_assigned_at and updated_at change;
// platform_id, provider, method_name and queue_key keep their first-insert
// values.
// NOTE(review): method_name is filled from candidate.ModelType — confirm that
// is intended.
func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error {
	if _, err := s.pool.Exec(ctx, `
INSERT INTO runtime_client_states (
client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at
)
VALUES ($1, $2::uuid, $3, $4, $5, 1, now())
ON CONFLICT (client_id) DO UPDATE
SET running_count = runtime_client_states.running_count + 1,
last_assigned_at = now(),
updated_at = now()`,
		candidate.ClientID,
		candidate.PlatformID,
		candidate.Provider,
		candidate.ModelType,
		candidate.QueueKey,
	); err != nil {
		return err
	}
	return nil
}
// RecordClientRelease decrements running_count for clientID, never going
// below zero, and overwrites last_error (an empty lastError clears it to
// NULL). An unknown clientID matches no row and is not reported as an error.
func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error {
	const releaseSQL = `
UPDATE runtime_client_states
SET running_count = GREATEST(running_count - 1, 0),
last_error = NULLIF($2, ''),
updated_at = now()
WHERE client_id = $1`
	_, execErr := s.pool.Exec(ctx, releaseSQL, clientID, lastError)
	return execErr
}