435 lines
12 KiB
Go
435 lines
12 KiB
Go
package store
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"math"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)
|
|
|
|
func (s *Store) ListTasks(ctx context.Context, user *auth.User, limit int) ([]GatewayTask, error) {
|
|
if limit <= 0 {
|
|
limit = 50
|
|
}
|
|
if limit > 100 {
|
|
limit = 100
|
|
}
|
|
gatewayUserID := localGatewayUserID(user)
|
|
apiKeyID := ""
|
|
userID := ""
|
|
if user != nil {
|
|
apiKeyID = strings.TrimSpace(user.APIKeyID)
|
|
userID = strings.TrimSpace(user.ID)
|
|
}
|
|
if gatewayUserID == "" && userID == "" {
|
|
return nil, ErrLocalUserRequired
|
|
}
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT `+gatewayTaskColumns+`
|
|
FROM gateway_tasks
|
|
WHERE (
|
|
(
|
|
NULLIF($1, '')::uuid IS NOT NULL
|
|
AND gateway_user_id = NULLIF($1, '')::uuid
|
|
)
|
|
OR (
|
|
NULLIF($1, '')::uuid IS NULL
|
|
AND NULLIF($2, '') IS NOT NULL
|
|
AND user_id = $2
|
|
)
|
|
)
|
|
AND (
|
|
NULLIF($3, '') IS NULL
|
|
OR api_key_id = $3
|
|
)
|
|
ORDER BY created_at DESC
|
|
LIMIT $4`, gatewayUserID, userID, apiKeyID, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]GatewayTask, 0)
|
|
for rows.Next() {
|
|
task, err := scanGatewayTask(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, task)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
|
|
normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'running',
|
|
model_type = NULLIF($2, ''),
|
|
normalized_request = $3::jsonb,
|
|
locked_at = now(),
|
|
heartbeat_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
|
|
requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
var attemptID string
|
|
err = tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_task_attempts (
|
|
task_id, attempt_no, platform_id, platform_model_id, client_id, queue_key,
|
|
status, simulated, request_snapshot
|
|
)
|
|
VALUES (
|
|
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6,
|
|
$7, $8, $9::jsonb
|
|
)
|
|
RETURNING id::text`,
|
|
input.TaskID,
|
|
input.AttemptNo,
|
|
input.PlatformID,
|
|
input.PlatformModelID,
|
|
input.ClientID,
|
|
input.QueueKey,
|
|
firstNonEmpty(input.Status, "running"),
|
|
input.Simulated,
|
|
string(requestJSON),
|
|
).Scan(&attemptID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now()
|
|
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
|
|
return "", err
|
|
}
|
|
return attemptID, tx.Commit(ctx)
|
|
}
|
|
|
|
func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error {
|
|
responseJSON, _ := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot))
|
|
usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_task_attempts
|
|
SET status = $2,
|
|
retryable = $3,
|
|
request_id = NULLIF($4, ''),
|
|
usage = $5::jsonb,
|
|
metrics = $6::jsonb,
|
|
response_snapshot = $7::jsonb,
|
|
response_started_at = $8::timestamptz,
|
|
response_finished_at = $9::timestamptz,
|
|
response_duration_ms = $10,
|
|
error_code = NULLIF($11, ''),
|
|
error_message = NULLIF($12, ''),
|
|
finished_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.AttemptID,
|
|
input.Status,
|
|
input.Retryable,
|
|
input.RequestID,
|
|
string(usageJSON),
|
|
string(metricsJSON),
|
|
string(responseJSON),
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
input.ErrorCode,
|
|
input.ErrorMessage,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) FinishTaskSuccess(ctx context.Context, input FinishTaskSuccessInput) (GatewayTask, error) {
|
|
resultJSON, _ := json.Marshal(emptyObjectIfNil(input.Result))
|
|
billingsJSON, _ := json.Marshal(input.Billings)
|
|
usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
billingSummaryJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingSummary))
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'succeeded',
|
|
result = $2::jsonb,
|
|
billings = $3::jsonb,
|
|
request_id = NULLIF($4, ''),
|
|
resolved_model = NULLIF($5, ''),
|
|
usage = $6::jsonb,
|
|
metrics = $7::jsonb,
|
|
billing_summary = $8::jsonb,
|
|
final_charge_amount = $9,
|
|
response_started_at = $10::timestamptz,
|
|
response_finished_at = $11::timestamptz,
|
|
response_duration_ms = $12,
|
|
error = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
finished_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.TaskID,
|
|
string(resultJSON),
|
|
string(billingsJSON),
|
|
input.RequestID,
|
|
input.ResolvedModel,
|
|
string(usageJSON),
|
|
string(metricsJSON),
|
|
string(billingSummaryJSON),
|
|
input.FinalChargeAmount,
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
); err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
return s.GetTask(ctx, input.TaskID)
|
|
}
|
|
|
|
func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
|
if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" {
|
|
return nil
|
|
}
|
|
currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"]))
|
|
if currency == "" || currency == "mixed" {
|
|
currency = "resource"
|
|
}
|
|
metadata, _ := json.Marshal(map[string]any{
|
|
"taskId": task.ID,
|
|
"kind": task.Kind,
|
|
"model": task.Model,
|
|
"resolvedModel": task.ResolvedModel,
|
|
"billings": task.Billings,
|
|
"billingSummary": task.BillingSummary,
|
|
})
|
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
|
if _, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_wallet_accounts (
|
|
gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
|
|
)
|
|
VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6)
|
|
ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
|
|
task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil {
|
|
return err
|
|
}
|
|
var exists bool
|
|
if err := tx.QueryRow(ctx, `
|
|
SELECT EXISTS (
|
|
SELECT 1
|
|
FROM gateway_wallet_transactions t
|
|
JOIN gateway_wallet_accounts a ON a.id = t.account_id
|
|
WHERE a.gateway_user_id = $1::uuid
|
|
AND a.currency = $2
|
|
AND t.idempotency_key = $3
|
|
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
var accountID string
|
|
var balanceBefore float64
|
|
var gatewayTenantID string
|
|
if err := tx.QueryRow(ctx, `
|
|
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
|
|
FROM gateway_wallet_accounts
|
|
WHERE gateway_user_id = $1::uuid
|
|
AND currency = $2
|
|
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
|
|
return err
|
|
}
|
|
amount := roundMoney(task.FinalChargeAmount)
|
|
balanceAfter := roundMoney(balanceBefore - amount)
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_wallet_accounts
|
|
SET balance = $2,
|
|
total_spent = total_spent + $3,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
|
|
return err
|
|
}
|
|
_, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_wallet_transactions (
|
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
|
)
|
|
VALUES (
|
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
|
)`,
|
|
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
|
|
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
|
|
return nil
|
|
}
|
|
return err
|
|
})
|
|
}
|
|
|
|
// billingIdempotencyKey builds the wallet-transaction idempotency key for a
// task's billing settlement, in the form "task:<taskID>:billing".
func billingIdempotencyKey(taskID string) string {
	const (
		keyPrefix = "task:"
		keySuffix = ":billing"
	)
	return keyPrefix + taskID + keySuffix
}
|
|
|
|
// roundMoney rounds value to six decimal places, rounding halves away from
// zero.
//
// NaN, ±Inf and values whose scaled magnitude reaches 2^53 are returned
// unchanged: at that size a float64 no longer carries sub-micro-unit
// precision, and scaling further could overflow to infinity.
func roundMoney(value float64) float64 {
	const scale = 1e6
	scaled := value * scale
	if math.IsNaN(scaled) || math.Abs(scaled) >= 1<<53 {
		return value
	}
	return math.Round(scaled) / scale
}
|
|
|
|
// taskBillingString returns value when it holds a string and "" for any
// other dynamic type, including nil.
func taskBillingString(value any) string {
	// A failed assertion leaves text at its zero value, the empty string.
	text, _ := value.(string)
	return text
}
|
|
|
|
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'failed',
|
|
error = NULLIF($2, ''),
|
|
error_code = NULLIF($3, ''),
|
|
error_message = NULLIF($2, ''),
|
|
request_id = NULLIF($4, ''),
|
|
metrics = $5::jsonb,
|
|
response_started_at = $6::timestamptz,
|
|
response_finished_at = $7::timestamptz,
|
|
response_duration_ms = $8,
|
|
finished_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.TaskID,
|
|
input.Message,
|
|
input.Code,
|
|
input.RequestID,
|
|
string(metricsJSON),
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
); err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
return s.GetTask(ctx, input.TaskID)
|
|
}
|
|
|
|
// nullableTime converts a time into a query argument: the zero time becomes
// an untyped nil and any other time is returned unchanged.
func nullableTime(value time.Time) any {
	if !value.IsZero() {
		return value
	}
	return nil
}
|
|
|
|
func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) {
|
|
payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
|
|
var event TaskEvent
|
|
var payloadBytes []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
WITH next_seq AS (
|
|
SELECT COALESCE(MAX(seq), 0) + 1 AS seq
|
|
FROM gateway_task_events
|
|
WHERE task_id = $1::uuid
|
|
)
|
|
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
|
|
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
|
|
FROM next_seq
|
|
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
|
|
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
|
|
taskID,
|
|
eventType,
|
|
status,
|
|
phase,
|
|
progress,
|
|
message,
|
|
string(payloadJSON),
|
|
simulated,
|
|
).Scan(
|
|
&event.ID,
|
|
&event.TaskID,
|
|
&event.Seq,
|
|
&event.EventType,
|
|
&event.Status,
|
|
&event.Phase,
|
|
&event.Progress,
|
|
&event.Message,
|
|
&payloadBytes,
|
|
&event.Simulated,
|
|
&event.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return TaskEvent{}, err
|
|
}
|
|
event.Payload = decodeObject(payloadBytes)
|
|
return event, nil
|
|
}
|
|
|
|
func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error {
|
|
if callbackURL == "" {
|
|
return nil
|
|
}
|
|
payloadJSON, _ := json.Marshal(map[string]any{
|
|
"taskId": event.TaskID,
|
|
"seq": event.Seq,
|
|
"eventType": event.EventType,
|
|
"status": event.Status,
|
|
"phase": event.Phase,
|
|
"progress": event.Progress,
|
|
"message": event.Message,
|
|
"payload": event.Payload,
|
|
"simulated": event.Simulated,
|
|
"createdAt": event.CreatedAt,
|
|
})
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload)
|
|
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb)
|
|
ON CONFLICT (task_id, seq, callback_url) DO NOTHING`,
|
|
event.TaskID,
|
|
event.ID,
|
|
event.Seq,
|
|
callbackURL,
|
|
string(payloadJSON),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO runtime_client_states (
|
|
client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at
|
|
)
|
|
VALUES ($1, $2::uuid, $3, $4, $5, 1, now())
|
|
ON CONFLICT (client_id) DO UPDATE
|
|
SET running_count = runtime_client_states.running_count + 1,
|
|
last_assigned_at = now(),
|
|
updated_at = now()`,
|
|
candidate.ClientID,
|
|
candidate.PlatformID,
|
|
candidate.Provider,
|
|
candidate.ModelType,
|
|
candidate.QueueKey,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE runtime_client_states
|
|
SET running_count = GREATEST(running_count - 1, 0),
|
|
last_error = NULLIF($2, ''),
|
|
updated_at = now()
|
|
WHERE client_id = $1`, clientID, lastError)
|
|
return err
|
|
}
|