package store

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

// ListTasks returns up to limit of the caller's gateway tasks, newest first.
//
// limit defaults to 50 when non-positive and is capped at 100. Tasks are
// matched by the local gateway user ID when one is available, otherwise by
// the trimmed user.ID. When user carries a non-empty APIKeyID, the result is
// further restricted to tasks with that api_key_id. ErrLocalUserRequired is
// returned when neither identifier is present.
func (s *Store) ListTasks(ctx context.Context, user *auth.User, limit int) ([]GatewayTask, error) {
	if limit <= 0 {
		limit = 50
	}
	if limit > 100 {
		limit = 100
	}
	gatewayUserID := localGatewayUserID(user)
	apiKeyID := ""
	userID := ""
	if user != nil {
		apiKeyID = strings.TrimSpace(user.APIKeyID)
		userID = strings.TrimSpace(user.ID)
	}
	if gatewayUserID == "" && userID == "" {
		return nil, ErrLocalUserRequired
	}
	rows, err := s.pool.Query(ctx, `
SELECT `+gatewayTaskColumns+`
FROM gateway_tasks
WHERE (
    (
      NULLIF($1, '')::uuid IS NOT NULL
      AND gateway_user_id = NULLIF($1, '')::uuid
    )
    OR (
      NULLIF($1, '')::uuid IS NULL
      AND NULLIF($2, '') IS NOT NULL
      AND user_id = $2
    )
  )
  AND (NULLIF($3, '') IS NULL OR api_key_id = $3)
ORDER BY created_at DESC
LIMIT $4`, gatewayUserID, userID, apiKeyID, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Non-nil even when empty so that encoding/json renders [] rather than null.
	items := make([]GatewayTask, 0)
	for rows.Next() {
		task, err := scanGatewayTask(rows)
		if err != nil {
			return nil, err
		}
		items = append(items, task)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// MarkTaskRunning switches the task identified by taskID to status
// 'running'. It stores modelType (empty becomes NULL) and the JSON-encoded
// normalizedRequest (passed through emptyObjectIfNil), and stamps locked_at,
// heartbeat_at and updated_at with the database clock.
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
	normalizedJSON, err := json.Marshal(emptyObjectIfNil(normalizedRequest))
	if err != nil {
		return fmt.Errorf("encoding normalized request for task %s: %w", taskID, err)
	}
	_, err = s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'running',
    model_type = NULLIF($2, ''),
    normalized_request = $3::jsonb,
    locked_at = now(),
    heartbeat_at = now(),
    updated_at = now()
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
	return err
}

// CreateTaskAttempt inserts a gateway_task_attempts row for input.TaskID and
// returns the new attempt's ID.
//
// Status defaults to "running" when input.Status is empty. Empty PlatformID,
// PlatformModelID and ClientID are stored as NULL. In the same transaction,
// the parent task's attempt_count is raised to at least input.AttemptNo.
// On error the returned ID is always "".
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
	requestJSON, err := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
	if err != nil {
		return "", fmt.Errorf("encoding request snapshot for task %s: %w", input.TaskID, err)
	}
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return "", err
	}
	// Undoes the insert if any later step fails; pgx documents that calling
	// Rollback after a successful Commit is harmless.
	defer tx.Rollback(ctx)

	var attemptID string
	err = tx.QueryRow(ctx, `
INSERT INTO gateway_task_attempts (
  task_id, attempt_no, platform_id, platform_model_id, client_id,
  queue_key, status, simulated, request_snapshot
) VALUES (
  $1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''),
  $6, $7, $8, $9::jsonb
)
RETURNING id::text`,
		input.TaskID,
		input.AttemptNo,
		input.PlatformID,
		input.PlatformModelID,
		input.ClientID,
		input.QueueKey,
		firstNonEmpty(input.Status, "running"),
		input.Simulated,
		string(requestJSON),
	).Scan(&attemptID)
	if err != nil {
		return "", err
	}
	// GREATEST keeps attempt_count monotonic if attempts are recorded out of order.
	if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks
SET attempt_count = GREATEST(attempt_count, $2),
    updated_at = now()
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
		return "", err
	}
	if err := tx.Commit(ctx); err != nil {
		return "", err
	}
	return attemptID, nil
}

// FinishTaskAttempt records the outcome of the attempt input.AttemptID.
//
// It stores the status, retryable flag, request ID, usage, metrics, response
// snapshot, response timing and error details, and sets finished_at to now().
// Empty RequestID, ErrorCode and ErrorMessage are stored as NULL, as are zero
// response timestamps.
func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error {
	responseJSON, err := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot))
	if err != nil {
		return fmt.Errorf("encoding response snapshot for attempt %s: %w", input.AttemptID, err)
	}
	usageJSON, err := json.Marshal(emptyObjectIfNil(input.Usage))
	if err != nil {
		return fmt.Errorf("encoding usage for attempt %s: %w", input.AttemptID, err)
	}
	metricsJSON, err := json.Marshal(emptyObjectIfNil(input.Metrics))
	if err != nil {
		return fmt.Errorf("encoding metrics for attempt %s: %w", input.AttemptID, err)
	}
	_, err = s.pool.Exec(ctx, `
UPDATE gateway_task_attempts
SET status = $2,
    retryable = $3,
    request_id = NULLIF($4, ''),
    usage = $5::jsonb,
    metrics = $6::jsonb,
    response_snapshot = $7::jsonb,
    response_started_at = $8::timestamptz,
    response_finished_at = $9::timestamptz,
    response_duration_ms = $10,
    error_code = NULLIF($11, ''),
    error_message = NULLIF($12, ''),
    finished_at = now()
WHERE id = $1::uuid`,
		input.AttemptID,
		input.Status,
		input.Retryable,
		input.RequestID,
		string(usageJSON),
		string(metricsJSON),
		string(responseJSON),
		nullableTime(input.ResponseStartedAt),
		nullableTime(input.ResponseFinishedAt),
		input.ResponseDurationMS,
		input.ErrorCode,
		input.ErrorMessage,
	)
	return err
}

// FinishTaskSuccess marks task input.TaskID as 'succeeded' and returns the
// task as re-read through GetTask.
//
// It stores the result, billings, usage, metrics, billing summary, final
// charge and response timing, and clears error, error_code and error_message.
// It does not move any money; see SettleTaskBilling for that.
func (s *Store) FinishTaskSuccess(ctx context.Context, input FinishTaskSuccessInput) (GatewayTask, error) {
	resultJSON, err := json.Marshal(emptyObjectIfNil(input.Result))
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding result for task %s: %w", input.TaskID, err)
	}
	// NOTE(review): Billings is not passed through emptyObjectIfNil; if it is a
	// nil slice/map it is stored as JSON null — confirm that is intended.
	billingsJSON, err := json.Marshal(input.Billings)
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding billings for task %s: %w", input.TaskID, err)
	}
	usageJSON, err := json.Marshal(emptyObjectIfNil(input.Usage))
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding usage for task %s: %w", input.TaskID, err)
	}
	metricsJSON, err := json.Marshal(emptyObjectIfNil(input.Metrics))
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding metrics for task %s: %w", input.TaskID, err)
	}
	billingSummaryJSON, err := json.Marshal(emptyObjectIfNil(input.BillingSummary))
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding billing summary for task %s: %w", input.TaskID, err)
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'succeeded',
    result = $2::jsonb,
    billings = $3::jsonb,
    request_id = NULLIF($4, ''),
    resolved_model = NULLIF($5, ''),
    usage = $6::jsonb,
    metrics = $7::jsonb,
    billing_summary = $8::jsonb,
    final_charge_amount = $9,
    response_started_at = $10::timestamptz,
    response_finished_at = $11::timestamptz,
    response_duration_ms = $12,
    error = NULL,
    error_code = NULL,
    error_message = NULL,
    finished_at = now(),
    updated_at = now()
WHERE id = $1::uuid`,
		input.TaskID,
		string(resultJSON),
		string(billingsJSON),
		input.RequestID,
		input.ResolvedModel,
		string(usageJSON),
		string(metricsJSON),
		string(billingSummaryJSON),
		input.FinalChargeAmount,
		nullableTime(input.ResponseStartedAt),
		nullableTime(input.ResponseFinishedAt),
		input.ResponseDurationMS,
	); err != nil {
		return GatewayTask{}, err
	}
	return s.GetTask(ctx, input.TaskID)
}

// SettleTaskBilling debits task.FinalChargeAmount from the wallet account of
// task.GatewayUserID and records a matching 'task_billing' debit transaction.
//
// The currency comes from task.BillingSummary["currency"]; "resource" is used
// when it is missing, empty or "mixed". The account is created if it does not
// exist yet. The call is a no-op when there is nothing to charge or the task
// has no gateway user.
//
// The debit is idempotent per task via billingIdempotencyKey(task.ID).
// Settling an already-settled task, including concurrently, charges nothing
// and returns nil. The balance may go negative; no funds check is performed.
func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
	if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" {
		return nil
	}
	currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"]))
	if currency == "" || currency == "mixed" {
		currency = "resource"
	}
	metadata, err := json.Marshal(map[string]any{
		"taskId":         task.ID,
		"kind":           task.Kind,
		"model":          task.Model,
		"resolvedModel":  task.ResolvedModel,
		"billings":       task.Billings,
		"billingSummary": task.BillingSummary,
	})
	if err != nil {
		return fmt.Errorf("encoding billing metadata for task %s: %w", task.ID, err)
	}
	idempotencyKey := billingIdempotencyKey(task.ID)

	// Set when the transaction insert hits a unique violation, i.e. (presumably
	// on the idempotency key) another settlement of this task already committed.
	duplicateDebit := false
	err = pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
		// Ensure the account row exists so that it can be locked below.
		if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_accounts (
  gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
) VALUES (
  NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6
)
ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
			task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil {
			return err
		}

		// Lock the account before the idempotency check. Concurrent
		// settlements serialize on this row lock, and under Read Committed the
		// check below then sees a debit committed by the previous lock holder.
		var accountID string
		var balanceBefore float64
		var gatewayTenantID string
		if err := tx.QueryRow(ctx, `
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid AND currency = $2
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
			return err
		}

		var alreadySettled bool
		if err := tx.QueryRow(ctx, `
SELECT EXISTS (
  SELECT 1
  FROM gateway_wallet_transactions
  WHERE account_id = $1::uuid AND idempotency_key = $2
)`, accountID, idempotencyKey).Scan(&alreadySettled); err != nil {
			return err
		}
		if alreadySettled {
			return nil
		}

		amount := roundMoney(task.FinalChargeAmount)
		balanceAfter := roundMoney(balanceBefore - amount)
		if _, err := tx.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET balance = $2,
    total_spent = total_spent + $3,
    updated_at = now()
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
			return err
		}

		_, insertErr := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
  account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
  amount, balance_before, balance_after, idempotency_key, reference_type,
  reference_id, metadata
) VALUES (
  $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
  $4, $5, $6, $7, 'gateway_task',
  $8, $9::jsonb
)`,
			accountID,
			firstNonEmpty(gatewayTenantID, task.GatewayTenantID),
			task.GatewayUserID,
			amount,
			roundMoney(balanceBefore),
			balanceAfter,
			idempotencyKey,
			task.ID,
			string(metadata),
		)
		var pgErr *pgconn.PgError
		if errors.As(insertErr, &pgErr) && pgErr.Code == "23505" { // unique_violation
			duplicateDebit = true
		}
		// PostgreSQL aborts the transaction after any failed statement, so a
		// duplicate must NOT be swallowed here. BeginFunc would then try to
		// COMMIT an aborted transaction, and pgx reports ErrTxCommitRollback.
		// Returning the error rolls back this attempt's balance update instead.
		return insertErr
	})
	if duplicateDebit {
		return nil
	}
	return err
}

// billingIdempotencyKey returns the wallet-transaction idempotency key used
// when settling taskID, of the form "task:<taskID>:billing".
func billingIdempotencyKey(taskID string) string {
	return "task:" + taskID + ":billing"
}

// roundMoney rounds value to six decimal places, rounding halves away from
// zero so that the result is symmetric for negative amounts.
func roundMoney(value float64) float64 {
	// math.Round replaces the previous int64(value*1e6+0.5) conversion. That
	// version overflowed for |value| above ~9.2e12, and its extra +0.5
	// addition could mis-round values just below a half.
	return math.Round(value*1e6) / 1e6
}

// taskBillingString returns value when it is a string, and "" for any other
// dynamic type (including nil).
func taskBillingString(value any) string {
	if text, ok := value.(string); ok {
		return text
	}
	return ""
}

// FinishTaskFailure marks task input.TaskID as 'failed' and returns the task
// as re-read through GetTask.
//
// input.Message is written to both the error and error_message columns. Empty
// Message, Code and RequestID are stored as NULL, as are zero response
// timestamps.
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
	metricsJSON, err := json.Marshal(emptyObjectIfNil(input.Metrics))
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding metrics for task %s: %w", input.TaskID, err)
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'failed',
    error = NULLIF($2, ''),
    error_code = NULLIF($3, ''),
    error_message = NULLIF($2, ''),
    request_id = NULLIF($4, ''),
    metrics = $5::jsonb,
    response_started_at = $6::timestamptz,
    response_finished_at = $7::timestamptz,
    response_duration_ms = $8,
    finished_at = now(),
    updated_at = now()
WHERE id = $1::uuid`,
		input.TaskID,
		input.Message,
		input.Code,
		input.RequestID,
		string(metricsJSON),
		nullableTime(input.ResponseStartedAt),
		nullableTime(input.ResponseFinishedAt),
		input.ResponseDurationMS,
	); err != nil {
		return GatewayTask{}, err
	}
	return s.GetTask(ctx, input.TaskID)
}

// nullableTime maps the zero time.Time to nil, so that it binds as SQL NULL,
// and passes any other value through unchanged.
func nullableTime(value time.Time) any {
	if value.IsZero() {
		return nil
	}
	return value
}

// AddTaskEvent appends an event to taskID's event log and returns the stored
// row.
//
// The event's seq is one more than the task's current maximum seq (1 for the
// first event). Empty status, phase and message are stored as NULL and read
// back as "".
//
// NOTE(review): seq is computed with MAX(seq)+1 inside a single statement;
// two concurrent calls for the same task can compute the same seq. Verify
// that callers serialize per task, or that a unique constraint surfaces this.
func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) {
	payloadJSON, err := json.Marshal(emptyObjectIfNil(payload))
	if err != nil {
		return TaskEvent{}, fmt.Errorf("encoding event payload for task %s: %w", taskID, err)
	}
	var event TaskEvent
	var payloadBytes []byte
	err = s.pool.QueryRow(ctx, `
WITH next_seq AS (
  SELECT COALESCE(MAX(seq), 0) + 1 AS seq
  FROM gateway_task_events
  WHERE task_id = $1::uuid
)
INSERT INTO gateway_task_events (
  task_id, seq, event_type, status, phase, progress, message, payload, simulated
)
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
FROM next_seq
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
  COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
		taskID,
		eventType,
		status,
		phase,
		progress,
		message,
		string(payloadJSON),
		simulated,
	).Scan(
		&event.ID,
		&event.TaskID,
		&event.Seq,
		&event.EventType,
		&event.Status,
		&event.Phase,
		&event.Progress,
		&event.Message,
		&payloadBytes,
		&event.Simulated,
		&event.CreatedAt,
	)
	if err != nil {
		return TaskEvent{}, err
	}
	event.Payload = decodeObject(payloadBytes)
	return event, nil
}

// QueueTaskCallback writes event to the callback outbox for delivery to
// callbackURL.
//
// It does nothing when callbackURL is empty. Re-queuing the same
// (task, seq, callbackURL) is ignored via ON CONFLICT DO NOTHING.
func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error {
	if callbackURL == "" {
		return nil
	}
	payloadJSON, err := json.Marshal(map[string]any{
		"taskId":    event.TaskID,
		"seq":       event.Seq,
		"eventType": event.EventType,
		"status":    event.Status,
		"phase":     event.Phase,
		"progress":  event.Progress,
		"message":   event.Message,
		"payload":   event.Payload,
		"simulated": event.Simulated,
		"createdAt": event.CreatedAt,
	})
	if err != nil {
		return fmt.Errorf("encoding callback payload for task %s: %w", event.TaskID, err)
	}
	_, err = s.pool.Exec(ctx, `
INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload)
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb)
ON CONFLICT (task_id, seq, callback_url) DO NOTHING`,
		event.TaskID,
		event.ID,
		event.Seq,
		callbackURL,
		string(payloadJSON),
	)
	return err
}

// RecordClientAssignment increments running_count for candidate.ClientID and
// stamps last_assigned_at.
//
// If the client has no state row yet, one is inserted with running_count 1.
// candidate.ModelType is stored in the method_name column. On conflict, only
// the counters and timestamps are updated; the platform, provider,
// method_name and queue_key of an existing row are left as they were.
func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error {
	_, err := s.pool.Exec(ctx, `
INSERT INTO runtime_client_states (
  client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at
) VALUES ($1, $2::uuid, $3, $4, $5, 1, now())
ON CONFLICT (client_id) DO UPDATE
SET running_count = runtime_client_states.running_count + 1,
    last_assigned_at = now(),
    updated_at = now()`,
		candidate.ClientID,
		candidate.PlatformID,
		candidate.Provider,
		candidate.ModelType,
		candidate.QueueKey,
	)
	return err
}

// RecordClientRelease decrements running_count for clientID (never below 0)
// and overwrites last_error with lastError. An empty lastError clears the
// column to NULL.
func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error {
	_, err := s.pool.Exec(ctx, `
UPDATE runtime_client_states
SET running_count = GREATEST(running_count - 1, 0),
    last_error = NULLIF($2, ''),
    updated_at = now()
WHERE client_id = $1`, clientID, lastError)
	return err
}