280 lines
7.9 KiB
Go
280 lines
7.9 KiB
Go
package store
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"time"
)
|
|
|
|
func (s *Store) MarkTaskRunning(ctx context.Context, taskID string, modelType string, normalizedRequest map[string]any) error {
|
|
normalizedJSON, _ := json.Marshal(emptyObjectIfNil(normalizedRequest))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'running',
|
|
model_type = NULLIF($2, ''),
|
|
normalized_request = $3::jsonb,
|
|
locked_at = now(),
|
|
heartbeat_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`, taskID, modelType, string(normalizedJSON))
|
|
return err
|
|
}
|
|
|
|
func (s *Store) CreateTaskAttempt(ctx context.Context, input CreateTaskAttemptInput) (string, error) {
|
|
requestJSON, _ := json.Marshal(emptyObjectIfNil(input.RequestSnapshot))
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
var attemptID string
|
|
err = tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_task_attempts (
|
|
task_id, attempt_no, platform_id, platform_model_id, client_id, queue_key,
|
|
status, simulated, request_snapshot
|
|
)
|
|
VALUES (
|
|
$1::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, '')::uuid, NULLIF($5, ''), $6,
|
|
$7, $8, $9::jsonb
|
|
)
|
|
RETURNING id::text`,
|
|
input.TaskID,
|
|
input.AttemptNo,
|
|
input.PlatformID,
|
|
input.PlatformModelID,
|
|
input.ClientID,
|
|
input.QueueKey,
|
|
firstNonEmpty(input.Status, "running"),
|
|
input.Simulated,
|
|
string(requestJSON),
|
|
).Scan(&attemptID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET attempt_count = GREATEST(attempt_count, $2), updated_at = now()
|
|
WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil {
|
|
return "", err
|
|
}
|
|
return attemptID, tx.Commit(ctx)
|
|
}
|
|
|
|
func (s *Store) FinishTaskAttempt(ctx context.Context, input FinishTaskAttemptInput) error {
|
|
responseJSON, _ := json.Marshal(emptyObjectIfNil(input.ResponseSnapshot))
|
|
usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_task_attempts
|
|
SET status = $2,
|
|
retryable = $3,
|
|
request_id = NULLIF($4, ''),
|
|
usage = $5::jsonb,
|
|
metrics = $6::jsonb,
|
|
response_snapshot = $7::jsonb,
|
|
response_started_at = $8::timestamptz,
|
|
response_finished_at = $9::timestamptz,
|
|
response_duration_ms = $10,
|
|
error_code = NULLIF($11, ''),
|
|
error_message = NULLIF($12, ''),
|
|
finished_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.AttemptID,
|
|
input.Status,
|
|
input.Retryable,
|
|
input.RequestID,
|
|
string(usageJSON),
|
|
string(metricsJSON),
|
|
string(responseJSON),
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
input.ErrorCode,
|
|
input.ErrorMessage,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) FinishTaskSuccess(ctx context.Context, input FinishTaskSuccessInput) (GatewayTask, error) {
|
|
resultJSON, _ := json.Marshal(emptyObjectIfNil(input.Result))
|
|
billingsJSON, _ := json.Marshal(input.Billings)
|
|
usageJSON, _ := json.Marshal(emptyObjectIfNil(input.Usage))
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
billingSummaryJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingSummary))
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'succeeded',
|
|
result = $2::jsonb,
|
|
billings = $3::jsonb,
|
|
request_id = NULLIF($4, ''),
|
|
resolved_model = NULLIF($5, ''),
|
|
usage = $6::jsonb,
|
|
metrics = $7::jsonb,
|
|
billing_summary = $8::jsonb,
|
|
final_charge_amount = $9,
|
|
response_started_at = $10::timestamptz,
|
|
response_finished_at = $11::timestamptz,
|
|
response_duration_ms = $12,
|
|
error = NULL,
|
|
error_code = NULL,
|
|
error_message = NULL,
|
|
finished_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.TaskID,
|
|
string(resultJSON),
|
|
string(billingsJSON),
|
|
input.RequestID,
|
|
input.ResolvedModel,
|
|
string(usageJSON),
|
|
string(metricsJSON),
|
|
string(billingSummaryJSON),
|
|
input.FinalChargeAmount,
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
); err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
return s.GetTask(ctx, input.TaskID)
|
|
}
|
|
|
|
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_tasks
|
|
SET status = 'failed',
|
|
error = NULLIF($2, ''),
|
|
error_code = NULLIF($3, ''),
|
|
error_message = NULLIF($2, ''),
|
|
request_id = NULLIF($4, ''),
|
|
metrics = $5::jsonb,
|
|
response_started_at = $6::timestamptz,
|
|
response_finished_at = $7::timestamptz,
|
|
response_duration_ms = $8,
|
|
finished_at = now(),
|
|
updated_at = now()
|
|
WHERE id = $1::uuid`,
|
|
input.TaskID,
|
|
input.Message,
|
|
input.Code,
|
|
input.RequestID,
|
|
string(metricsJSON),
|
|
nullableTime(input.ResponseStartedAt),
|
|
nullableTime(input.ResponseFinishedAt),
|
|
input.ResponseDurationMS,
|
|
); err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
return s.GetTask(ctx, input.TaskID)
|
|
}
|
|
|
|
// nullableTime prepares a time.Time for use as a query argument. The zero
// time becomes an untyped nil (bound as SQL NULL). Any other instant is
// passed through unchanged.
func nullableTime(value time.Time) any {
	if !value.IsZero() {
		return value
	}
	return nil
}
|
|
|
|
func (s *Store) AddTaskEvent(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) (TaskEvent, error) {
|
|
payloadJSON, _ := json.Marshal(emptyObjectIfNil(payload))
|
|
var event TaskEvent
|
|
var payloadBytes []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
WITH next_seq AS (
|
|
SELECT COALESCE(MAX(seq), 0) + 1 AS seq
|
|
FROM gateway_task_events
|
|
WHERE task_id = $1::uuid
|
|
)
|
|
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
|
|
SELECT $1::uuid, next_seq.seq, $2, NULLIF($3, ''), NULLIF($4, ''), $5, NULLIF($6, ''), $7::jsonb, $8
|
|
FROM next_seq
|
|
RETURNING id::text, task_id::text, seq, event_type, COALESCE(status, ''), COALESCE(phase, ''),
|
|
COALESCE(progress, 0)::float8, COALESCE(message, ''), payload, simulated, created_at`,
|
|
taskID,
|
|
eventType,
|
|
status,
|
|
phase,
|
|
progress,
|
|
message,
|
|
string(payloadJSON),
|
|
simulated,
|
|
).Scan(
|
|
&event.ID,
|
|
&event.TaskID,
|
|
&event.Seq,
|
|
&event.EventType,
|
|
&event.Status,
|
|
&event.Phase,
|
|
&event.Progress,
|
|
&event.Message,
|
|
&payloadBytes,
|
|
&event.Simulated,
|
|
&event.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
return TaskEvent{}, err
|
|
}
|
|
event.Payload = decodeObject(payloadBytes)
|
|
return event, nil
|
|
}
|
|
|
|
func (s *Store) QueueTaskCallback(ctx context.Context, event TaskEvent, callbackURL string) error {
|
|
if callbackURL == "" {
|
|
return nil
|
|
}
|
|
payloadJSON, _ := json.Marshal(map[string]any{
|
|
"taskId": event.TaskID,
|
|
"seq": event.Seq,
|
|
"eventType": event.EventType,
|
|
"status": event.Status,
|
|
"phase": event.Phase,
|
|
"progress": event.Progress,
|
|
"message": event.Message,
|
|
"payload": event.Payload,
|
|
"simulated": event.Simulated,
|
|
"createdAt": event.CreatedAt,
|
|
})
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO gateway_task_callback_outbox (task_id, event_id, seq, callback_url, payload)
|
|
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb)
|
|
ON CONFLICT (task_id, seq, callback_url) DO NOTHING`,
|
|
event.TaskID,
|
|
event.ID,
|
|
event.Seq,
|
|
callbackURL,
|
|
string(payloadJSON),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RecordClientAssignment(ctx context.Context, candidate RuntimeModelCandidate) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO runtime_client_states (
|
|
client_id, platform_id, provider, method_name, queue_key, running_count, last_assigned_at
|
|
)
|
|
VALUES ($1, $2::uuid, $3, $4, $5, 1, now())
|
|
ON CONFLICT (client_id) DO UPDATE
|
|
SET running_count = runtime_client_states.running_count + 1,
|
|
last_assigned_at = now(),
|
|
updated_at = now()`,
|
|
candidate.ClientID,
|
|
candidate.PlatformID,
|
|
candidate.Provider,
|
|
candidate.ModelType,
|
|
candidate.QueueKey,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) RecordClientRelease(ctx context.Context, clientID string, lastError string) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE runtime_client_states
|
|
SET running_count = GREATEST(running_count - 1, 0),
|
|
last_error = NULLIF($2, ''),
|
|
updated_at = now()
|
|
WHERE client_id = $1`, clientID, lastError)
|
|
return err
|
|
}
|