package store import ( "context" "encoding/json" "strconv" "strings" "time" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" ) type TaskListFilter struct { Query string ModelType string CreatedFrom *time.Time CreatedTo *time.Time Page int PageSize int } type TaskListResult struct { Items []GatewayTask Total int Page int PageSize int } func (s *Store) ListTasks(ctx context.Context, user *auth.User, filter TaskListFilter) (TaskListResult, error) { page := filter.Page if page <= 0 { page = 1 } pageSize := filter.PageSize if pageSize <= 0 { pageSize = 50 } if pageSize > 100 { pageSize = 100 } offset := (page - 1) * pageSize gatewayUserID := localGatewayUserID(user) apiKeyID := "" userID := "" if user != nil { apiKeyID = strings.TrimSpace(user.APIKeyID) userID = strings.TrimSpace(user.ID) } if gatewayUserID == "" && userID == "" { return TaskListResult{}, ErrLocalUserRequired } queryPattern := "" if query := strings.TrimSpace(filter.Query); query != "" { queryPattern = "%" + query + "%" } args := []any{ gatewayUserID, userID, apiKeyID, queryPattern, strings.TrimSpace(filter.ModelType), nullableTaskListTime(filter.CreatedFrom), nullableTaskListTime(filter.CreatedTo), } whereSQL := ` WHERE ( ( NULLIF($1, '')::uuid IS NOT NULL AND gateway_user_id = NULLIF($1, '')::uuid ) OR ( NULLIF($1, '')::uuid IS NULL AND NULLIF($2, '') IS NOT NULL AND user_id = $2 ) ) AND ( NULLIF($3, '') IS NULL OR api_key_id = $3 ) AND ( NULLIF($4, '') IS NULL OR id::text ILIKE $4 OR COALESCE(request_id, '') ILIKE $4 OR kind ILIKE $4 OR model ILIKE $4 OR COALESCE(requested_model, '') ILIKE $4 OR COALESCE(resolved_model, '') ILIKE $4 OR COALESCE(api_key_id, '') ILIKE $4 OR COALESCE(api_key_name, '') ILIKE $4 OR COALESCE(api_key_prefix, '') ILIKE $4 OR status ILIKE $4 OR COALESCE(model_type, '') ILIKE $4 ) AND ( NULLIF($5, '') IS NULL OR model_type = $5 ) AND ( $6::timestamptz IS NULL OR created_at >= $6::timestamptz ) AND ( $7::timestamptz IS NULL OR created_at <= $7::timestamptz )` var total int if err := s.pool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks `+whereSQL, args...).Scan(&total); err != nil { return TaskListResult{}, err } queryArgs := append(args, pageSize, offset) rows, err := s.pool.Query(ctx, ` SELECT `+gatewayTaskColumns+` FROM gateway_tasks `+whereSQL+` ORDER BY created_at DESC LIMIT $8 OFFSET $9`, queryArgs...) if err != nil { return TaskListResult{}, err } defer rows.Close() items := make([]GatewayTask, 0) for rows.Next() { task, err := scanGatewayTask(rows) if err != nil { return TaskListResult{}, err } items = append(items, task) } if err := rows.Err(); err != nil { return TaskListResult{}, err } items, err = s.attachTaskAttempts(ctx, items) if err != nil { return TaskListResult{}, err } return TaskListResult{ Items: items, Total: total, Page: page, PageSize: pageSize, }, nil } func nullableTaskListTime(value *time.Time) any { if value == nil { return nil } return *value } func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error { normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest)) _, err := s.pool.Exec(ctx, ` UPDATE gateway_tasks SET status = 'running', model_type = NULLIF($2, ''), normalized_request = $3::jsonb, locked_at = now(), heartbeat_at = now(), updated_at = now() WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON)) return err } func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) { requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot)) metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) tx, err := s.pool.Begin(ctx) if err != nil { return "", err } defer tx.Rollback(ctx) var attemptID string err = tx.QueryRow(ctx, ` INSERT INTO gateway_task_attempts ( task_id, attempt_no, platform_id, platform_model_id, client_id, queue_key, status, simulated, request_snapshot, metrics ) VALUES ( $1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6, $7, $8, $9::jsonb, $10::jsonb ) RETURNING id::text`, input.TaskID, input.AttemptNo, input.PlatformID, input.PlatformModelID, input.ClientID, input.QueueKey, firstNonEmpty(input.Status, "running"), input.Simulated, string(requestJSON), string(metricsJSON), ).Scan(&attemptID) if err != nil { return "", err } if _, err := tx.Exec(ctx, ` UPDATE gateway_tasks SET attempt_count = GREATEST(attempt_count, $2), updated_at = now() WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil { return "", err } return attemptID, tx.Commit(ctx) } func (s *Store) attachTaskAttempts(ctx context.Context, items []GatewayTask) ([]GatewayTask, error) { if len(items) == 0 { return items, nil } taskIDs := make([]string, 0, len(items)) for _, item := range items { taskIDs = append(taskIDs, item.ID) } attemptsByTaskID, err := s.listTaskAttemptsByTaskIDs(ctx, taskIDs) if err != nil { return nil, err } for index := range items { items[index].Attempts = attemptsByTaskID[items[index].ID] } return items, nil } func (s *Store) ListTaskAttempts(ctx context.Context, taskID string) ([]TaskAttempt, error) { attemptsByTaskID, err := s.listTaskAttemptsByTaskIDs(ctx, []string{taskID}) if err != nil { return nil, err } return attemptsByTaskID[taskID], nil } func (s *Store) AppendTaskAttemptTrace(ctx context.Context, taskID string, attemptNo int, entry map[string]any) error { entryJSON, _ := json.Marshal(emptyObjectIfNil(entry)) _, err := s.pool.Exec(ctx, ` UPDATE gateway_task_attempts SET metrics = jsonb_set( COALESCE(metrics, '{}'::jsonb), '{trace}', ( CASE WHEN jsonb_typeof(COALESCE(metrics->'trace', '[]'::jsonb)) = 'array' THEN COALESCE(metrics->'trace', '[]'::jsonb) ELSE '[]'::jsonb END ) || jsonb_build_array($3::jsonb), true ) WHERE task_id = $1::uuid AND attempt_no = $2`, taskID, attemptNo, string(entryJSON)) return err } func (s *Store) listTaskAttemptsByTaskIDs(ctx context.Context, taskIDs []string) (map[string][]TaskAttempt, error) { itemsByTaskID := map[string][]TaskAttempt{} if len(taskIDs) == 0 { return itemsByTaskID, nil } rows, err := s.pool.Query(ctx, ` SELECT a.id::text, a.task_id::text, a.attempt_no, COALESCE(a.platform_id::text, ''), COALESCE(p.name, ''), COALESCE(p.provider, ''), COALESCE(a.platform_model_id::text, ''), COALESCE(pm.model_name, ''), COALESCE(NULLIF(pm.provider_model_name, ''), pm.model_name, ''), COALESCE(pm.model_alias, ''), COALESCE(a.client_id, ''), a.queue_key, a.status, a.retryable, a.simulated, COALESCE(a.request_id, ''), COALESCE(a.usage, '{}'::jsonb), COALESCE(a.metrics, '{}'::jsonb), a.request_snapshot, COALESCE(a.response_snapshot, '{}'::jsonb), COALESCE(a.response_started_at::text, ''), COALESCE(a.response_finished_at::text, ''), COALESCE(a.response_duration_ms, 0), COALESCE(a.error_code, ''), COALESCE(a.error_message, ''), a.started_at, COALESCE(a.finished_at::text, '') FROM gateway_task_attempts a LEFT JOIN integration_platforms p ON p.id = a.platform_id LEFT JOIN platform_models pm ON pm.id = a.platform_model_id WHERE a.task_id::text = ANY($1) ORDER BY a.task_id, a.attempt_no`, taskIDs) if err != nil { return nil, err } defer rows.Close() for rows.Next() { item, err := scanTaskAttempt(rows) if err != nil { return nil, err } itemsByTaskID[item.TaskID] = append(itemsByTaskID[item.TaskID], item) } if err := rows.Err(); err != nil { return nil, err } return itemsByTaskID, nil } func scanTaskAttempt(scanner taskScanner) (TaskAttempt, error) { var item TaskAttempt var usageBytes []byte var metricsBytes []byte var requestBytes []byte var responseBytes []byte if err := scanner.Scan( &item.ID, &item.TaskID, &item.AttemptNo, &item.PlatformID, &item.PlatformName, &item.Provider, &item.PlatformModelID, &item.ModelName, &item.ProviderModelName, &item.ModelAlias, &item.ClientID, &item.QueueKey, &item.Status, &item.Retryable, &item.Simulated, &item.RequestID, &usageBytes, &metricsBytes, &requestBytes, &responseBytes, &item.ResponseStartedAt, &item.ResponseFinishedAt, &item.ResponseDurationMS, &item.ErrorCode, &item.ErrorMessage, &item.StartedAt, &item.FinishedAt, ); err != nil { return TaskAttempt{}, err } item.Usage = decodeObject(usageBytes) item.Metrics = decodeObject(metricsBytes) item.RequestSnapshot = decodeObject(requestBytes) item.ResponseSnapshot = decodeObject(responseBytes) enrichTaskAttemptFromMetrics(&item) return item, nil } func enrichTaskAttemptFromMetrics(item *TaskAttempt) { if item == nil || len(item.Metrics) == 0 { return } item.PlatformID = firstNonEmpty(item.PlatformID, taskAttemptMetricString(item.Metrics, "platformId")) item.PlatformName = firstNonEmpty(item.PlatformName, taskAttemptMetricString(item.Metrics, "platformName")) item.Provider = firstNonEmpty(item.Provider, taskAttemptMetricString(item.Metrics, "provider")) item.PlatformModelID = firstNonEmpty(item.PlatformModelID, taskAttemptMetricString(item.Metrics, "platformModelId")) item.ModelName = firstNonEmpty(item.ModelName, taskAttemptMetricString(item.Metrics, "resolvedModel"), taskAttemptMetricString(item.Metrics, "modelName")) item.ProviderModelName = firstNonEmpty(item.ProviderModelName, taskAttemptMetricString(item.Metrics, "providerModel")) item.ModelAlias = firstNonEmpty(item.ModelAlias, taskAttemptMetricString(item.Metrics, "modelAlias")) item.ModelType = firstNonEmpty(item.ModelType, taskAttemptMetricString(item.Metrics, "modelType")) item.ClientID = firstNonEmpty(item.ClientID, taskAttemptMetricString(item.Metrics, "clientId")) item.StatusCode = taskAttemptMetricInt(item.Metrics, "statusCode") } func taskAttemptMetricString(metrics map[string]any, key string) string { value, _ := metrics[key].(string) return strings.TrimSpace(value) } func taskAttemptMetricInt(metrics map[string]any, key string) int { switch value := metrics[key].(type) { case int: return value case int64: return int(value) case float64: return int(value) case json.Number: next, _ := value.Int64() return int(next) case string: next, _ := strconv.Atoi(strings.TrimSpace(value)) return next default: return 0 } } func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error { responseJSON, _ := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot)) usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage)) metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) _, err := s.pool.Exec(ctx, ` UPDATE gateway_task_attempts SET status = $2, retryable = $3, request_id = NULLIF($4, ''), usage = $5::jsonb, metrics = $6::jsonb, response_snapshot = $7::jsonb, response_started_at = $8::timestamptz, response_finished_at = $9::timestamptz, response_duration_ms = $10, error_code = NULLIF($11, ''), error_message = NULLIF($12, ''), finished_at = now() WHERE id = $1::uuid`, input.AttemptID, input.Status, input.Retryable, input.RequestID, string(usageJSON), string(metricsJSON), string(responseJSON), nullableTime(input.ResponseStartedAt), nullableTime(input.ResponseFinishedAt), input.ResponseDurationMS, input.ErrorCode, input.ErrorMessage, ) return err } func (s *Store) FinishTaskSuccess(ctx context.Context, input FinishTaskSuccessInput) (GatewayTask, error) { resultJSON, _ := json.Marshal(emptyObjectIfNil(input.Result)) billingsJSON, _ := json.Marshal(input.Billings) usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage)) metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) billingSummaryJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingSummary)) if _, err := s.pool.Exec(ctx, ` UPDATE gateway_tasks SET status = 'succeeded', result = $2::jsonb, billings = $3::jsonb, request_id = NULLIF($4, ''), resolved_model = NULLIF($5, ''), usage = $6::jsonb, metrics = $7::jsonb, billing_summary = $8::jsonb, final_charge_amount = $9, response_started_at = $10::timestamptz, response_finished_at = $11::timestamptz, response_duration_ms = $12, error = NULL, error_code = NULL, error_message = NULL, finished_at = now(), updated_at = now() WHERE id = $1::uuid`, input.TaskID, string(resultJSON), string(billingsJSON), input.RequestID, input.ResolvedModel, string(usageJSON), string(metricsJSON), string(billingSummaryJSON), input.FinalChargeAmount, nullableTime(input.ResponseStartedAt), nullableTime(input.ResponseFinishedAt), input.ResponseDurationMS, ); err != nil { return GatewayTask{}, err } return s.GetTask(ctx, input.TaskID) } func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error { if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" { return nil } currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"])) if currency == "" || currency == "mixed" { currency = "resource" } metadata, _ := json.Marshal(map[string]any{ "taskId": task.ID, "kind": task.Kind, "model": task.Model, "resolvedModel": task.ResolvedModel, "billings": task.Billings, "billingSummary": task.BillingSummary, }) return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error { if _, err := tx.Exec(ctx, ` INSERT INTO gateway_wallet_accounts ( gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency ) VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6) ON CONFLICT (gateway_user_id, currency) DO NOTHING`, task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil { return err } var exists bool if err := tx.QueryRow(ctx, ` SELECT EXISTS ( SELECT 1 FROM gateway_wallet_transactions t JOIN gateway_wallet_accounts a ON a.id = t.account_id WHERE a.gateway_user_id = $1::uuid AND a.currency = $2 AND t.idempotency_key = $3 )`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil { return err } if exists { return nil } var accountID string var balanceBefore float64 var gatewayTenantID string if err := tx.QueryRow(ctx, ` SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '') FROM gateway_wallet_accounts WHERE gateway_user_id = $1::uuid AND currency = $2 FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil { return err } amount := roundMoney(task.FinalChargeAmount) balanceAfter := roundMoney(balanceBefore - amount) if _, err := tx.Exec(ctx, ` UPDATE gateway_wallet_accounts SET balance = $2, total_spent = total_spent + $3, updated_at = now() WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil { return err } _, err := tx.Exec(ctx, ` INSERT INTO gateway_wallet_transactions ( account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type, amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata ) VALUES ( $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing', $4, $5, $6, $7, 'gateway_task', $8, $9::jsonb )`, accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { return nil } return err }) } func billingIdempotencyKey(taskID string) string { return "task:" + taskID + ":billing" } func roundMoney(value float64) float64 { if value < 0 { return -roundMoney(-value) } return float64(int64(value*1000000+0.5)) / 1000000 } func taskBillingString(value any) string { if text, ok := value.(string); ok { return text } return "" } func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) { metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics)) if _, err := s.pool.Exec(ctx, ` UPDATE gateway_tasks SET status = 'failed', error = NULLIF($2, ''), error_code = NULLIF($3, ''), error_message = NULLIF($2, ''), request_id = NULLIF($4, ''), metrics = $5::jsonb, response_started_at = $6::timestamptz, response_finished_at = $7::timestamptz, response_duration_ms = $8, finished_at = now(), updated_at = now() WHERE id = $1::uuid`, input.TaskID, input.Message, input.Code, input.RequestID, string(metricsJSON), nullableTime(input.ResponseStartedAt), nullableTime(input.ResponseFinishedAt), input.ResponseDurationMS, ); err != nil { return GatewayTask{}, err } return s.GetTask(ctx, input.TaskID) } func nullableTime(value time.Time) any { if value.IsZero() { return nil } return value } func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) { payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload)) var event TaskEvent var payloadBytes []byte err := s.pool.QueryRow(ctx, ` WITH next_seq AS ( SELECT COALESCE(MAX(seq), 0) + 1 AS seq FROM gateway_task_events WHERE task_id = $1::uuid ) INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated) SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8 FROM next_seq RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''), COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`, taskID, eventType, status, phase, progress, message, string(payloadJSON), simulated, ).Scan( &event.ID, &event.TaskID, &event.Seq, &event.EventType, &event.Status, &event.Phase, &event.Progress, &event.Message, &payloadBytes, &event.Simulated, &event.CreatedAt, ) if err != nil { return TaskEvent{}, err } event.Payload = decodeObject(payloadBytes) return event, nil } func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error { if callbackURL == "" { return nil } payloadJSON, _ := json.Marshal(map[string]any{ "taskId": event.TaskID, "seq": event.Seq, "eventType": event.EventType, "status": event.Status, "phase": event.Phase, "progress": event.Progress, "message": event.Message, "payload": event.Payload, "simulated": event.Simulated, "createdAt": event.CreatedAt, }) _, err := s.pool.Exec(ctx, ` INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload) VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb) ON CONFLICT (task_id, seq, callback_url) DO NOTHING`, event.TaskID, event.ID, event.Seq, callbackURL, string(payloadJSON), ) return err } func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error { _, err := s.pool.Exec(ctx, ` INSERT INTO runtime_client_states ( client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at ) VALUES ($1, $2::uuid, $3, $4, $5, 1, now()) ON CONFLICT (client_id) DO UPDATE SET running_count = runtime_client_states.running_count + 1, last_assigned_at = now(), updated_at = now()`, candidate.ClientID, candidate.PlatformID, candidate.Provider, candidate.ModelType, candidate.QueueKey, ) return err } func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error { _, err := s.pool.Exec(ctx, ` UPDATE runtime_client_states SET running_count = GREATEST(running_count - 1, 0), last_error = NULLIF($2, ''), updated_at = now() WHERE client_id = $1`, clientID, lastError) return err }