easyai-ai-gateway/apps/api/internal/store/tasks_runtime.go

218 lines
6.2 KiB
Go

package store
import (
"context"
"encoding/json"
)
// MarkTaskRunning moves the task identified by taskID to the 'running' status.
// It records the model type (stored as NULL when empty) and the normalized
// request snapshot. It also stamps locked_at, heartbeat_at and updated_at with
// the database's now().
//
// normalizedRequest is passed through emptyObjectIfNil before encoding
// (presumably so a nil map is stored as an empty JSON object — confirm against
// that helper). The number of affected rows is not checked, so an unknown
// taskID is not reported as an error.
//
// It returns an error if the request cannot be encoded as JSON or if the
// UPDATE fails.
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
	normalizedJSON, err := json.Marshal(emptyObjectIfNil(normalizedRequest))
	if err != nil {
		// Surface the encoding failure directly instead of sending an empty
		// string to the ::jsonb cast and getting an unrelated database error.
		return err
	}
	_, err = s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'running',
model_type = NULLIF($2, ''),
normalized_request = $3::jsonb,
locked_at = now(),
heartbeat_at = now(),
updated_at = now()
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
	return err
}
// CreateTaskAttempt inserts a new row into gateway_task_attempts. In the same
// transaction, it raises gateway_tasks.attempt_count to at least
// input.AttemptNo. It returns the new attempt's ID as text.
//
// Empty PlatformID, PlatformModelID and ClientID are stored as NULL, and an
// empty Status defaults to "running". The request snapshot is encoded as JSON
// before the transaction is opened.
//
// On any error — JSON encoding, either statement, or the final commit — the
// returned ID is the empty string, so callers never see the ID of an attempt
// that was not persisted.
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
	requestJSON, err := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
	if err != nil {
		return "", err
	}
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return "", err
	}
	// Guarantees cleanup on every early return. Assuming pgx semantics, a
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback(ctx)
	var attemptID string
	err = tx.QueryRow(ctx, `
INSERT INTO gateway_task_attempts (
task_id, attempt_no, platform_id, platform_model_id, client_id, queue_key,
status, simulated, request_snapshot
)
VALUES (
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6,
$7, $8, $9::jsonb
)
RETURNING id::text`,
		input.TaskID,
		input.AttemptNo,
		input.PlatformID,
		input.PlatformModelID,
		input.ClientID,
		input.QueueKey,
		firstNonEmpty(input.Status, "running"),
		input.Simulated,
		string(requestJSON),
	).Scan(&attemptID)
	if err != nil {
		return "", err
	}
	// GREATEST keeps attempt_count monotonic if attempts are recorded out of order.
	if _, err := tx.Exec(ctx, `
UPDATE gateway_tasks
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now()
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
		return "", err
	}
	if err := tx.Commit(ctx); err != nil {
		// The attempt row was not persisted; do not leak its ID.
		return "", err
	}
	return attemptID, nil
}
// FinishTaskAttempt records the outcome of the attempt identified by
// input.AttemptID. It sets the attempt's status, retryable flag, response
// snapshot and error details, and stamps finished_at with now().
//
// Empty ErrorCode and ErrorMessage are stored as NULL. The number of affected
// rows is not checked, so an unknown AttemptID is not reported as an error.
//
// It returns an error if the response snapshot cannot be encoded as JSON or if
// the UPDATE fails.
func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error {
	responseJSON, err := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot))
	if err != nil {
		// Fail before the round trip rather than sending '' to the ::jsonb cast.
		return err
	}
	_, err = s.pool.Exec(ctx, `
UPDATE gateway_task_attempts
SET status = $2,
retryable = $3,
response_snapshot = $4::jsonb,
error_code = NULLIF($5, ''),
error_message = NULLIF($6, ''),
finished_at = now()
WHERE id = $1::uuid`,
		input.AttemptID,
		input.Status,
		input.Retryable,
		string(responseJSON),
		input.ErrorCode,
		input.ErrorMessage,
	)
	return err
}
// FinishTaskSuccess marks the task as 'succeeded' and stores its result and
// billing records. It clears error, error_code and error_message, stamps
// finished_at and updated_at, and then returns the re-read task via GetTask.
//
// result is passed through emptyObjectIfNil before encoding. billings is
// encoded as-is, so a nil slice is stored as JSON null rather than [].
//
// It returns an error if either value cannot be encoded as JSON, if the UPDATE
// fails, or if GetTask fails.
func (s *Store) FinishTaskSuccess(ctx context.Context, taskID string, result map[string]any, billings []any) (GatewayTask, error) {
	resultJSON, err := json.Marshal(emptyObjectIfNil(result))
	if err != nil {
		return GatewayTask{}, err
	}
	billingsJSON, err := json.Marshal(billings)
	if err != nil {
		return GatewayTask{}, err
	}
	if _, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'succeeded',
result = $2::jsonb,
billings = $3::jsonb,
error = NULL,
error_code = NULL,
error_message = NULL,
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid`, taskID, string(resultJSON), string(billingsJSON)); err != nil {
		return GatewayTask{}, err
	}
	return s.GetTask(ctx, taskID)
}
// FinishTaskFailure sets the task's status to 'failed' and records the failure.
// The message is written to both the error and error_message columns, and the
// code to error_code. An empty code or message becomes NULL. It stamps
// finished_at and updated_at, and then returns the task as re-read by GetTask.
func (s *Store) FinishTaskFailure(ctx context.Context, taskID string, code string, message string) (GatewayTask, error) {
	// Placeholders: $2 is the message (used twice), $3 is the code. Note the
	// bind order below differs from the parameter order of this function.
	const markFailed = `
UPDATE gateway_tasks
SET status = 'failed',
error = NULLIF($2, ''),
error_code = NULLIF($3, ''),
error_message = NULLIF($2, ''),
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid`
	_, execErr := s.pool.Exec(ctx, markFailed, taskID, message, code)
	if execErr != nil {
		return GatewayTask{}, execErr
	}
	return s.GetTask(ctx, taskID)
}
// AddTaskEvent appends an event to gateway_task_events for taskID and returns
// the stored row.
//
// The event's seq is one greater than the task's current maximum seq, or 1 for
// the first event. Empty status, phase and message are stored as NULL and read
// back as empty strings. The returned payload is decoded from the stored jsonb
// via decodeObject.
//
// It returns an error if payload cannot be encoded as JSON or if the INSERT
// fails.
//
// NOTE(review): seq is computed with MAX(seq)+1 inside a single statement. Two
// concurrent calls for the same task can observe the same maximum and pick the
// same seq. Depending on whether (task_id, seq) is unique, one call fails or
// both get duplicate seqs. If concurrent writers per task are possible, lock
// the parent task row (or use an advisory lock) in a transaction before
// inserting.
func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) {
	payloadJSON, err := json.Marshal(emptyObjectIfNil(payload))
	if err != nil {
		// Fail before the round trip rather than sending '' to the ::jsonb cast.
		return TaskEvent{}, err
	}
	var event TaskEvent
	var payloadBytes []byte
	err = s.pool.QueryRow(ctx, `
WITH next_seq AS (
SELECT COALESCE(MAX(seq), 0) + 1 AS seq
FROM gateway_task_events
WHERE task_id = $1::uuid
)
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
FROM next_seq
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
		taskID,
		eventType,
		status,
		phase,
		progress,
		message,
		string(payloadJSON),
		simulated,
	).Scan(
		&event.ID,
		&event.TaskID,
		&event.Seq,
		&event.EventType,
		&event.Status,
		&event.Phase,
		&event.Progress,
		&event.Message,
		&payloadBytes,
		&event.Simulated,
		&event.CreatedAt,
	)
	if err != nil {
		return TaskEvent{}, err
	}
	event.Payload = decodeObject(payloadBytes)
	return event, nil
}
// QueueTaskCallback enqueues event for delivery to callbackURL by inserting a
// row into gateway_task_callback_outbox. The row's payload is a JSON object
// with the event's fields under camelCase keys.
//
// It is a no-op when callbackURL is empty. Re-queueing the same
// (task, seq, callbackURL) is silently ignored (ON CONFLICT DO NOTHING), so
// repeated calls for one event are safe.
//
// It returns an error if the payload cannot be encoded as JSON or if the INSERT
// fails.
func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error {
	if callbackURL == "" {
		return nil
	}
	payloadJSON, err := json.Marshal(map[string]any{
		"taskId":    event.TaskID,
		"seq":       event.Seq,
		"eventType": event.EventType,
		"status":    event.Status,
		"phase":     event.Phase,
		"progress":  event.Progress,
		"message":   event.Message,
		"payload":   event.Payload,
		"simulated": event.Simulated,
		"createdAt": event.CreatedAt,
	})
	if err != nil {
		// e.g. event.Payload holding a value encoding/json cannot represent.
		return err
	}
	_, err = s.pool.Exec(ctx, `
INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload)
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb)
ON CONFLICT (task_id, seq, callback_url) DO NOTHING`,
		event.TaskID,
		event.ID,
		event.Seq,
		callbackURL,
		string(payloadJSON),
	)
	return err
}
// RecordClientAssignment records that work was handed to candidate's client.
// It upserts the client's runtime_client_states row, keyed by client_id.
//
// On first sight, the client is inserted with running_count = 1. Afterwards,
// running_count is incremented and last_assigned_at/updated_at are refreshed.
//
// NOTE(review): the conflict branch does not touch platform_id, provider,
// method_name or queue_key. They keep the values from the first insert even if
// the candidate's values differ later; confirm that is intended.
//
// NOTE(review): method_name is populated from candidate.ModelType; verify this
// mapping is deliberate.
//
// NOTE(review): platform_id is cast with $2::uuid without NULLIF, so an empty
// PlatformID makes the statement fail.
func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error {
	const upsertAssignment = `
INSERT INTO runtime_client_states (
client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at
)
VALUES ($1, $2::uuid, $3, $4, $5, 1, now())
ON CONFLICT (client_id) DO UPDATE
SET running_count = runtime_client_states.running_count + 1,
last_assigned_at = now(),
updated_at = now()`
	if _, execErr := s.pool.Exec(ctx, upsertAssignment,
		candidate.ClientID,
		candidate.PlatformID,
		candidate.Provider,
		candidate.ModelType,
		candidate.QueueKey,
	); execErr != nil {
		return execErr
	}
	return nil
}
// RecordClientRelease records that one unit of work finished on clientID.
// It decrements the client's running_count, never below zero, and overwrites
// last_error with lastError. An empty lastError stores NULL, which clears any
// earlier error. It also refreshes updated_at.
//
// If no runtime_client_states row exists for clientID, nothing is updated and
// no error is returned.
func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error {
	const releaseClient = `
UPDATE runtime_client_states
SET running_count = GREATEST(running_count - 1, 0),
last_error = NULLIF($2, ''),
updated_at = now()
WHERE client_id = $1`
	if _, execErr := s.pool.Exec(ctx, releaseClient, clientID, lastError); execErr != nil {
		return execErr
	}
	return nil
}