820 lines
23 KiB
Go
820 lines
23 KiB
Go
package store
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)
|
|
|
|
// TaskListFilter narrows and paginates the tasks returned by ListTasks.
type TaskListFilter struct {
	// Query is a free-text search term; ListTasks trims it and ignores it
	// when blank.
	Query string
	// ModelType, when non-blank, must equal the task's model_type.
	ModelType string
	// CreatedFrom and CreatedTo are optional inclusive bounds on created_at;
	// nil leaves that side unbounded.
	CreatedFrom, CreatedTo *time.Time
	// Page is 1-based; PageSize is the number of items per page. ListTasks
	// replaces non-positive values with its defaults.
	Page, PageSize int
}
|
|
|
|
type TaskListResult struct {
|
|
Items []GatewayTask
|
|
Total int
|
|
Page int
|
|
PageSize int
|
|
}
|
|
|
|
func (s *Store) ListTasks(ctx context.Context, user *auth.User, filter TaskListFilter) (TaskListResult, error) {
|
|
page := filter.Page
|
|
if page <= 0 {
|
|
page = 1
|
|
}
|
|
pageSize := filter.PageSize
|
|
if pageSize <= 0 {
|
|
pageSize = 50
|
|
}
|
|
if pageSize > 100 {
|
|
pageSize = 100
|
|
}
|
|
offset := (page - 1) * pageSize
|
|
gatewayUserID := localGatewayUserID(user)
|
|
apiKeyID := ""
|
|
userID := ""
|
|
if user != nil {
|
|
apiKeyID = strings.TrimSpace(user.APIKeyID)
|
|
userID = strings.TrimSpace(user.ID)
|
|
}
|
|
if gatewayUserID == "" && userID == "" {
|
|
return TaskListResult{}, ErrLocalUserRequired
|
|
}
|
|
queryPattern := ""
|
|
if query := strings.TrimSpace(filter.Query); query != "" {
|
|
queryPattern = "%" + query + "%"
|
|
}
|
|
args := []any{
|
|
gatewayUserID,
|
|
userID,
|
|
apiKeyID,
|
|
queryPattern,
|
|
strings.TrimSpace(filter.ModelType),
|
|
nullableTaskListTime(filter.CreatedFrom),
|
|
nullableTaskListTime(filter.CreatedTo),
|
|
}
|
|
whereSQL := `
|
|
WHERE (
|
|
(
|
|
NULLIF($1, '')::uuid IS NOT NULL
|
|
AND gateway_user_id = NULLIF($1, '')::uuid
|
|
)
|
|
OR (
|
|
NULLIF($1, '')::uuid IS NULL
|
|
AND NULLIF($2, '') IS NOT NULL
|
|
AND user_id = $2
|
|
)
|
|
)
|
|
AND (
|
|
NULLIF($3, '') IS NULL
|
|
OR api_key_id = $3
|
|
)
|
|
AND (
|
|
NULLIF($4, '') IS NULL
|
|
OR id::text ILIKE $4
|
|
OR COALESCE(request_id, '') ILIKE $4
|
|
OR kind ILIKE $4
|
|
OR model ILIKE $4
|
|
OR COALESCE(requested_model, '') ILIKE $4
|
|
OR COALESCE(resolved_model, '') ILIKE $4
|
|
OR COALESCE(api_key_id, '') ILIKE $4
|
|
OR COALESCE(api_key_name, '') ILIKE $4
|
|
OR COALESCE(api_key_prefix, '') ILIKE $4
|
|
OR status ILIKE $4
|
|
OR COALESCE(model_type, '') ILIKE $4
|
|
)
|
|
AND (
|
|
NULLIF($5, '') IS NULL
|
|
OR model_type = $5
|
|
)
|
|
AND (
|
|
$6::timestamptz IS NULL
|
|
OR created_at >= $6::timestamptz
|
|
)
|
|
AND (
|
|
$7::timestamptz IS NULL
|
|
OR created_at <= $7::timestamptz
|
|
)`
|
|
var total int
|
|
if err := s.pool.QueryRow(ctx, `SELECT count(*) FROM gateway_tasks `+whereSQL, args...).Scan(&total); err != nil {
|
|
return TaskListResult{}, err
|
|
}
|
|
queryArgs := append(args, pageSize, offset)
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT `+gatewayTaskColumns+`
|
|
FROM gateway_tasks
|
|
`+whereSQL+`
|
|
ORDER BY created_at DESC
|
|
LIMIT $8 OFFSET $9`, queryArgs...)
|
|
if err != nil {
|
|
return TaskListResult{}, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]GatewayTask, 0)
|
|
for rows.Next() {
|
|
task, err := scanGatewayTask(rows)
|
|
if err != nil {
|
|
return TaskListResult{}, err
|
|
}
|
|
items = append(items, task)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return TaskListResult{}, err
|
|
}
|
|
items, err = s.attachTaskAttempts(ctx, items)
|
|
if err != nil {
|
|
return TaskListResult{}, err
|
|
}
|
|
return TaskListResult{
|
|
Items: items,
|
|
Total: total,
|
|
Page: page,
|
|
PageSize: pageSize,
|
|
}, nil
|
|
}
|
|
|
|
// nullableTaskListTime turns an optional time filter into a query argument:
// a nil pointer becomes nil (SQL NULL), otherwise the pointed-to time.Time is
// returned by value.
func nullableTaskListTime(value *time.Time) any {
	if value != nil {
		return *value
	}
	return nil
}
|
|
|
|
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
|
|
normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'running',
|
|
model_type = NULLIF($2::text, ''),
|
|
normalized_request = $3::jsonb,
|
|
locked_at = now(),
|
|
heartbeat_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
|
|
return err
|
|
}
|
|
|
|
func (s *Store) ClaimAsyncQueuedTask(ctx context.Context, workerID string) (GatewayTask, error) {
|
|
return scanGatewayTask(s.pool.QueryRow(ctx, `
|
|
WITH picked AS (
|
|
SELECT id AS task_id
|
|
FROM gateway_tasks
|
|
WHERE async_mode = true
|
|
AND status = 'queued'
|
|
AND next_run_at <= now()
|
|
ORDER BY priority ASC, created_at ASC
|
|
LIMIT 1
|
|
FOR UPDATE SKIP LOCKED
|
|
)
|
|
UPDATE gateway_tasks t
|
|
SET status = 'running',
|
|
locked_by = NULLIF($1::text, ''),
|
|
locked_at = now(),
|
|
heartbeat_at = now(),
|
|
updated_at = now()
|
|
FROM picked
|
|
WHERE t.id = picked.task_id
|
|
RETURNING `+gatewayTaskColumns, workerID))
|
|
}
|
|
|
|
func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) {
|
|
if delay < time.Second {
|
|
delay = time.Second
|
|
}
|
|
if delay > 10*time.Minute {
|
|
delay = 10 * time.Minute
|
|
}
|
|
nextRunAt := time.Now().Add(delay)
|
|
return scanGatewayTask(s.pool.QueryRow(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'queued',
|
|
locked_by = NULL,
|
|
locked_at = NULL,
|
|
heartbeat_at = NULL,
|
|
next_run_at = $2::timestamptz,
|
|
error = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid
|
|
RETURNING `+gatewayTaskColumns, taskID, nextRunAt))
|
|
}
|
|
|
|
func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error {
|
|
if riverJobID <= 0 {
|
|
return nil
|
|
}
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET river_job_id = $2,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`, taskID, riverJobID)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) SetTaskRemoteTask(ctx context.Context, taskID string, attemptID string, remoteTaskID string, payload map[string]any) error {
|
|
payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
|
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET remote_task_id = NULLIF($2::text, ''),
|
|
remote_task_payload = $3::jsonb,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
taskID,
|
|
remoteTaskID,
|
|
string(payloadJSON),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
if strings.TrimSpace(attemptID) == "" {
|
|
return nil
|
|
}
|
|
_, err := tx.Exec(ctx, `
|
|
UPDATE gateway_task_attempts
|
|
SET remote_task_id = NULLIF($2::text, ''),
|
|
response_snapshot = COALESCE(response_snapshot, '{}'::jsonb) || jsonb_build_object('remote_task_payload', $3::jsonb)
|
|
WHERE id = $1::uuid`,
|
|
attemptID,
|
|
remoteTaskID,
|
|
string(payloadJSON),
|
|
)
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
|
|
if limit <= 0 {
|
|
limit = 500
|
|
}
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, priority
|
|
FROM gateway_tasks
|
|
WHERE async_mode = true
|
|
AND status IN ('queued', 'running')
|
|
ORDER BY priority ASC, created_at ASC
|
|
LIMIT $1`, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
items := make([]AsyncTaskQueueItem, 0)
|
|
for rows.Next() {
|
|
var item AsyncTaskQueueItem
|
|
if err := rows.Scan(&item.ID, &item.Priority); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
|
|
requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
var attemptID string
|
|
err = tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_task_attempts (
|
|
task_id, attempt_no, platform_id, platform_model_id, client_id, queue_key,
|
|
status, simulated, request_snapshot, metrics
|
|
)
|
|
VALUES (
|
|
$1::uuid, $2::int, NULLIF($3::text, '')::uuid, NULLIF($4::text, '')::uuid, NULLIF($5::text, ''), $6,
|
|
$7, $8, $9::jsonb, $10::jsonb
|
|
)
|
|
RETURNING id::text`,
|
|
input.TaskID,
|
|
input.AttemptNo,
|
|
input.PlatformID,
|
|
input.PlatformModelID,
|
|
input.ClientID,
|
|
input.QueueKey,
|
|
firstNonEmpty(input.Status, "running"),
|
|
input.Simulated,
|
|
string(requestJSON),
|
|
string(metricsJSON),
|
|
).Scan(&attemptID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET attempt_count = GREATEST(attempt_count, $2::int), updated_at = now()
|
|
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
|
|
return "", err
|
|
}
|
|
return attemptID, tx.Commit(ctx)
|
|
}
|
|
|
|
func (s *Store) attachTaskAttempts(ctx context.Context, items []GatewayTask) ([]GatewayTask, error) {
|
|
if len(items) == 0 {
|
|
return items, nil
|
|
}
|
|
taskIDs := make([]string, 0, len(items))
|
|
for _, item := range items {
|
|
taskIDs = append(taskIDs, item.ID)
|
|
}
|
|
attemptsByTaskID, err := s.listTaskAttemptsByTaskIDs(ctx, taskIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for index := range items {
|
|
items[index].Attempts = attemptsByTaskID[items[index].ID]
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (s *Store) ListTaskAttempts(ctx context.Context, taskID string) ([]TaskAttempt, error) {
|
|
attemptsByTaskID, err := s.listTaskAttemptsByTaskIDs(ctx, []string{taskID})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return attemptsByTaskID[taskID], nil
|
|
}
|
|
|
|
func (s *Store) AppendTaskAttemptTrace(ctx context.Context, taskID string, attemptNo int, entry map[string]any) error {
|
|
entryJSON, _ := json.Marshal(emptyObjectIfNil(entry))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_task_attempts
|
|
SET metrics = jsonb_set(
|
|
COALESCE(metrics, '{}'::jsonb),
|
|
'{trace}',
|
|
(
|
|
CASE
|
|
WHEN jsonb_typeof(COALESCE(metrics->'trace', '[]'::jsonb)) = 'array'
|
|
THEN COALESCE(metrics->'trace', '[]'::jsonb)
|
|
ELSE '[]'::jsonb
|
|
END
|
|
) || jsonb_build_array($3::jsonb),
|
|
true
|
|
)
|
|
WHERE task_id = $1::uuid
|
|
AND attempt_no = $2::int`, taskID, attemptNo, string(entryJSON))
|
|
return err
|
|
}
|
|
|
|
func (s *Store) listTaskAttemptsByTaskIDs(ctx context.Context, taskIDs []string) (map[string][]TaskAttempt, error) {
|
|
itemsByTaskID := map[string][]TaskAttempt{}
|
|
if len(taskIDs) == 0 {
|
|
return itemsByTaskID, nil
|
|
}
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT a.id::text, a.task_id::text, a.attempt_no,
|
|
COALESCE(a.platform_id::text, ''), COALESCE(p.name, ''), COALESCE(p.provider, ''),
|
|
COALESCE(a.platform_model_id::text, ''), COALESCE(pm.model_name, ''),
|
|
COALESCE(NULLIF(pm.provider_model_name, ''), pm.model_name, ''),
|
|
COALESCE(pm.model_alias, ''),
|
|
COALESCE(a.client_id, ''), a.queue_key, a.status, a.retryable, a.simulated,
|
|
COALESCE(a.request_id, ''), COALESCE(a.usage, '{}'::jsonb), COALESCE(a.metrics, '{}'::jsonb),
|
|
a.request_snapshot, COALESCE(a.response_snapshot, '{}'::jsonb),
|
|
COALESCE(a.response_started_at::text, ''), COALESCE(a.response_finished_at::text, ''),
|
|
COALESCE(a.response_duration_ms, 0), COALESCE(a.error_code, ''), COALESCE(a.error_message, ''),
|
|
a.started_at, COALESCE(a.finished_at::text, '')
|
|
FROM gateway_task_attempts a
|
|
LEFT JOIN integration_platforms p ON p.id = a.platform_id
|
|
LEFT JOIN platform_models pm ON pm.id = a.platform_model_id
|
|
WHERE a.task_id::text = ANY($1)
|
|
ORDER BY a.task_id, a.attempt_no`, taskIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
item, err := scanTaskAttempt(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
itemsByTaskID[item.TaskID] = append(itemsByTaskID[item.TaskID], item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return itemsByTaskID, nil
|
|
}
|
|
|
|
func scanTaskAttempt(scanner taskScanner) (TaskAttempt, error) {
|
|
var item TaskAttempt
|
|
var usageBytes []byte
|
|
var metricsBytes []byte
|
|
var requestBytes []byte
|
|
var responseBytes []byte
|
|
if err := scanner.Scan(
|
|
&item.ID,
|
|
&item.TaskID,
|
|
&item.AttemptNo,
|
|
&item.PlatformID,
|
|
&item.PlatformName,
|
|
&item.Provider,
|
|
&item.PlatformModelID,
|
|
&item.ModelName,
|
|
&item.ProviderModelName,
|
|
&item.ModelAlias,
|
|
&item.ClientID,
|
|
&item.QueueKey,
|
|
&item.Status,
|
|
&item.Retryable,
|
|
&item.Simulated,
|
|
&item.RequestID,
|
|
&usageBytes,
|
|
&metricsBytes,
|
|
&requestBytes,
|
|
&responseBytes,
|
|
&item.ResponseStartedAt,
|
|
&item.ResponseFinishedAt,
|
|
&item.ResponseDurationMS,
|
|
&item.ErrorCode,
|
|
&item.ErrorMessage,
|
|
&item.StartedAt,
|
|
&item.FinishedAt,
|
|
); err != nil {
|
|
return TaskAttempt{}, err
|
|
}
|
|
item.Usage = decodeObject(usageBytes)
|
|
item.Metrics = decodeObject(metricsBytes)
|
|
item.RequestSnapshot = decodeObject(requestBytes)
|
|
item.ResponseSnapshot = decodeObject(responseBytes)
|
|
enrichTaskAttemptFromMetrics(&item)
|
|
return item, nil
|
|
}
|
|
|
|
func enrichTaskAttemptFromMetrics(item *TaskAttempt) {
|
|
if item == nil || len(item.Metrics) == 0 {
|
|
return
|
|
}
|
|
item.PlatformID = firstNonEmpty(item.PlatformID, taskAttemptMetricString(item.Metrics, "platformId"))
|
|
item.PlatformName = firstNonEmpty(item.PlatformName, taskAttemptMetricString(item.Metrics, "platformName"))
|
|
item.Provider = firstNonEmpty(item.Provider, taskAttemptMetricString(item.Metrics, "provider"))
|
|
item.PlatformModelID = firstNonEmpty(item.PlatformModelID, taskAttemptMetricString(item.Metrics, "platformModelId"))
|
|
item.ModelName = firstNonEmpty(item.ModelName, taskAttemptMetricString(item.Metrics, "resolvedModel"), taskAttemptMetricString(item.Metrics, "modelName"))
|
|
item.ProviderModelName = firstNonEmpty(item.ProviderModelName, taskAttemptMetricString(item.Metrics, "providerModel"))
|
|
item.ModelAlias = firstNonEmpty(item.ModelAlias, taskAttemptMetricString(item.Metrics, "modelAlias"))
|
|
item.ModelType = firstNonEmpty(item.ModelType, taskAttemptMetricString(item.Metrics, "modelType"))
|
|
item.ClientID = firstNonEmpty(item.ClientID, taskAttemptMetricString(item.Metrics, "clientId"))
|
|
item.StatusCode = taskAttemptMetricInt(item.Metrics, "statusCode")
|
|
}
|
|
|
|
// taskAttemptMetricString returns metrics[key] with surrounding whitespace
// trimmed when it holds a string. It returns "" for a missing key, a nil map,
// or a non-string value.
func taskAttemptMetricString(metrics map[string]any, key string) string {
	if text, ok := metrics[key].(string); ok {
		return strings.TrimSpace(text)
	}
	return ""
}
|
|
|
|
// taskAttemptMetricInt reads metrics[key] as an int, tolerating the shapes a
// number takes after JSON round-trips.
//
// Accepted forms:
//   - Go integers (int, int32, int64) and floats (float32, float64), with
//     floats truncated toward zero.
//   - json.Number.
//   - Numeric strings (surrounding whitespace ignored).
//
// Integral-looking values written with a fraction ("429.0") are accepted
// for both json.Number and strings. A missing key, nil map, non-numeric
// value, NaN, ±Inf or out-of-range float yields 0.
func taskAttemptMetricInt(metrics map[string]any, key string) int {
	switch value := metrics[key].(type) {
	case int:
		return value
	case int32:
		return int(value)
	case int64:
		return int(value)
	case float32:
		return taskAttemptFloatToInt(float64(value))
	case float64:
		return taskAttemptFloatToInt(value)
	case json.Number:
		if next, err := value.Int64(); err == nil {
			return int(next)
		}
		if next, err := value.Float64(); err == nil {
			return taskAttemptFloatToInt(next)
		}
		return 0
	case string:
		text := strings.TrimSpace(value)
		if next, err := strconv.Atoi(text); err == nil {
			return next
		}
		if next, err := strconv.ParseFloat(text, 64); err == nil {
			return taskAttemptFloatToInt(next)
		}
		return 0
	default:
		return 0
	}
}

// taskAttemptFloatToInt truncates f toward zero. It returns 0 for NaN, ±Inf
// and magnitudes beyond roughly 9.2e18, where Go's float-to-int conversion
// result is implementation-specific.
func taskAttemptFloatToInt(f float64) int {
	// Written so that NaN (which fails every comparison) also lands here.
	if !(f > -9.2e18 && f < 9.2e18) {
		return 0
	}
	return int(f)
}
|
|
|
|
func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error {
|
|
responseJSON, _ := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot))
|
|
usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_task_attempts
|
|
SET status = $2::text,
|
|
retryable = $3,
|
|
request_id = NULLIF($4::text, ''),
|
|
usage = $5::jsonb,
|
|
metrics = $6::jsonb,
|
|
response_snapshot = $7::jsonb,
|
|
response_started_at = $8::timestamptz,
|
|
response_finished_at = $9::timestamptz,
|
|
response_duration_ms = $10,
|
|
error_code = NULLIF($11::text, ''),
|
|
error_message = NULLIF($12::text, ''),
|
|
finished_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.AttemptID,
|
|
input.Status,
|
|
input.Retryable,
|
|
input.RequestID,
|
|
string(usageJSON),
|
|
string(metricsJSON),
|
|
string(responseJSON),
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
input.ErrorCode,
|
|
input.ErrorMessage,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) FinishTaskSuccess(ctx context.Context, input FinishTaskSuccessInput) (GatewayTask, error) {
|
|
resultJSON, _ := json.Marshal(emptyObjectIfNil(input.Result))
|
|
billingsJSON, _ := json.Marshal(input.Billings)
|
|
usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
billingSummaryJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingSummary))
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'succeeded',
|
|
result = $2::jsonb,
|
|
billings = $3::jsonb,
|
|
request_id = NULLIF($4, ''),
|
|
resolved_model = NULLIF($5, ''),
|
|
usage = $6::jsonb,
|
|
metrics = $7::jsonb,
|
|
billing_summary = $8::jsonb,
|
|
final_charge_amount = $9,
|
|
response_started_at = $10::timestamptz,
|
|
response_finished_at = $11::timestamptz,
|
|
response_duration_ms = $12,
|
|
error = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
locked_by = NULL,
|
|
locked_at = NULL,
|
|
heartbeat_at = NULL,
|
|
finished_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.TaskID,
|
|
string(resultJSON),
|
|
string(billingsJSON),
|
|
input.RequestID,
|
|
input.ResolvedModel,
|
|
string(usageJSON),
|
|
string(metricsJSON),
|
|
string(billingSummaryJSON),
|
|
input.FinalChargeAmount,
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
); err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
return s.GetTask(ctx, input.TaskID)
|
|
}
|
|
|
|
func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
|
if task.FinalChargeAmount <= 0 || strings.TrimSpace(task.GatewayUserID) == "" {
|
|
return nil
|
|
}
|
|
currency := strings.TrimSpace(taskBillingString(task.BillingSummary["currency"]))
|
|
if currency == "" || currency == "mixed" {
|
|
currency = "resource"
|
|
}
|
|
metadata, _ := json.Marshal(map[string]any{
|
|
"taskId": task.ID,
|
|
"kind": task.Kind,
|
|
"model": task.Model,
|
|
"resolvedModel": task.ResolvedModel,
|
|
"billings": task.Billings,
|
|
"billingSummary": task.BillingSummary,
|
|
})
|
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
|
if _, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_wallet_accounts (
|
|
gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
|
|
)
|
|
VALUES (NULLIF($1, '')::uuid, $2::uuid, NULLIF($3, ''), NULLIF($4, ''), NULLIF($5, ''), $6)
|
|
ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
|
|
task.GatewayTenantID, task.GatewayUserID, task.TenantID, task.TenantKey, task.UserID, currency); err != nil {
|
|
return err
|
|
}
|
|
var exists bool
|
|
if err := tx.QueryRow(ctx, `
|
|
SELECT EXISTS (
|
|
SELECT 1
|
|
FROM gateway_wallet_transactions t
|
|
JOIN gateway_wallet_accounts a ON a.id = t.account_id
|
|
WHERE a.gateway_user_id = $1::uuid
|
|
AND a.currency = $2
|
|
AND t.idempotency_key = $3
|
|
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
return nil
|
|
}
|
|
var accountID string
|
|
var balanceBefore float64
|
|
var gatewayTenantID string
|
|
if err := tx.QueryRow(ctx, `
|
|
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
|
|
FROM gateway_wallet_accounts
|
|
WHERE gateway_user_id = $1::uuid
|
|
AND currency = $2
|
|
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
|
|
return err
|
|
}
|
|
amount := roundMoney(task.FinalChargeAmount)
|
|
balanceAfter := roundMoney(balanceBefore - amount)
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_wallet_accounts
|
|
SET balance = $2,
|
|
total_spent = total_spent + $3,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
|
|
return err
|
|
}
|
|
_, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_wallet_transactions (
|
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
|
)
|
|
VALUES (
|
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
|
)`,
|
|
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
|
|
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
|
|
return nil
|
|
}
|
|
return err
|
|
})
|
|
}
|
|
|
|
// billingIdempotencyKey builds the wallet-transaction idempotency key that
// SettleTaskBilling uses to bill a task at most once, in the form
// "task:<taskID>:billing".
func billingIdempotencyKey(taskID string) string {
	const prefix, suffix = "task:", ":billing"
	return prefix + taskID + suffix
}
|
|
|
|
// roundMoney rounds value to 6 decimal places. Halves are rounded away from
// zero, so positive and negative amounts round symmetrically.
//
// math.Round is used instead of truncating value*1e6+0.5 through int64, for
// two reasons:
//   - the int64 conversion overflows for |value| above roughly 9.2e12;
//   - adding 0.5 before truncation can itself round a value just below a
//     half upward.
//
// NaN and ±Inf pass through unchanged.
func roundMoney(value float64) float64 {
	return math.Round(value*1e6) / 1e6
}
|
|
|
|
// taskBillingString returns value unchanged when it is a string, and ""
// otherwise (including nil).
func taskBillingString(value any) string {
	text, _ := value.(string)
	return text
}
|
|
|
|
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'failed',
|
|
error = NULLIF($2::text, ''),
|
|
error_code = NULLIF($3::text, ''),
|
|
error_message = NULLIF($2::text, ''),
|
|
request_id = NULLIF($4::text, ''),
|
|
metrics = $5::jsonb,
|
|
response_started_at = $6::timestamptz,
|
|
response_finished_at = $7::timestamptz,
|
|
response_duration_ms = $8,
|
|
locked_by = NULL,
|
|
locked_at = NULL,
|
|
heartbeat_at = NULL,
|
|
finished_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.TaskID,
|
|
input.Message,
|
|
input.Code,
|
|
input.RequestID,
|
|
string(metricsJSON),
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
); err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
return s.GetTask(ctx, input.TaskID)
|
|
}
|
|
|
|
// nullableTime maps the zero time.Time to nil (SQL NULL) and passes any
// other value through unchanged, so unset timestamps are stored as NULL.
func nullableTime(value time.Time) any {
	var arg any
	if !value.IsZero() {
		arg = value
	}
	return arg
}
|
|
|
|
func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) {
|
|
payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
|
|
var event TaskEvent
|
|
var payloadBytes []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
WITH next_seq AS (
|
|
SELECT COALESCE(MAX(seq), 0) + 1 AS seq
|
|
FROM gateway_task_events
|
|
WHERE task_id = $1::uuid
|
|
)
|
|
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
|
|
SELECT $1::uuid, next_seq.seq, $2::text, NULLIF($3::text, ''), NULLIF($4::text, ''), $5, NULLIF($6::text, ''), $7::jsonb, $8
|
|
FROM next_seq
|
|
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
|
|
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
|
|
taskID,
|
|
eventType,
|
|
status,
|
|
phase,
|
|
progress,
|
|
message,
|
|
string(payloadJSON),
|
|
simulated,
|
|
).Scan(
|
|
&event.ID,
|
|
&event.TaskID,
|
|
&event.Seq,
|
|
&event.EventType,
|
|
&event.Status,
|
|
&event.Phase,
|
|
&event.Progress,
|
|
&event.Message,
|
|
&payloadBytes,
|
|
&event.Simulated,
|
|
&event.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return TaskEvent{}, err
|
|
}
|
|
event.Payload = decodeObject(payloadBytes)
|
|
return event, nil
|
|
}
|
|
|
|
func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error {
|
|
if callbackURL == "" {
|
|
return nil
|
|
}
|
|
payloadJSON, _ := json.Marshal(map[string]any{
|
|
"taskId": event.TaskID,
|
|
"seq": event.Seq,
|
|
"eventType": event.EventType,
|
|
"status": event.Status,
|
|
"phase": event.Phase,
|
|
"progress": event.Progress,
|
|
"message": event.Message,
|
|
"payload": event.Payload,
|
|
"simulated": event.Simulated,
|
|
"createdAt": event.CreatedAt,
|
|
})
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload)
|
|
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb)
|
|
ON CONFLICT (task_id, seq, callback_url) DO NOTHING`,
|
|
event.TaskID,
|
|
event.ID,
|
|
event.Seq,
|
|
callbackURL,
|
|
string(payloadJSON),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO runtime_client_states (
|
|
client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at
|
|
)
|
|
VALUES ($1, $2::uuid, $3, $4, $5, 1, now())
|
|
ON CONFLICT (client_id) DO UPDATE
|
|
SET running_count = runtime_client_states.running_count + 1,
|
|
last_assigned_at = now(),
|
|
updated_at = now()`,
|
|
candidate.ClientID,
|
|
candidate.PlatformID,
|
|
candidate.Provider,
|
|
candidate.ModelType,
|
|
candidate.QueueKey,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE runtime_client_states
|
|
SET running_count = GREATEST(running_count - 1, 0),
|
|
last_error = NULLIF($2::text, ''),
|
|
updated_at = now()
|
|
WHERE client_id = $1`, clientID, lastError)
|
|
return err
|
|
}
|