150 lines
4.2 KiB
Go
150 lines
4.2 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
type (
	// ConversationMessageInput is one message to be stored in a conversation.
	// Hash is stored as the message_hash (the per-conversation dedup key), a
	// blank Role is stored as NULL, Snapshot is stored as the JSON message
	// snapshot, and AssetSHA256s is stored in asset_sha256s (presumably the
	// SHA-256 digests of referenced assets — confirm with producers).
	ConversationMessageInput struct {
		Hash         string
		Role         string
		Snapshot     map[string]any
		AssetSHA256s []string
	}

	// TaskMessageRefInput points a task at a stored conversation message,
	// placed at Position within the task's message list.
	TaskMessageRefInput struct {
		MessageID string
		Position  int
	}

	// ConversationMessageRef is a task's reference to a stored conversation
	// message at a given position, together with the decoded message snapshot.
	ConversationMessageRef struct {
		MessageID string         `json:"messageId"`
		Position  int            `json:"position"`
		Message   map[string]any `json:"message"`
	}
)
|
|
|
|
func (s *Store) EnsureConversation(ctx context.Context, user *auth.User, conversationKey string, metadata map[string]any) (string, error) {
|
|
conversationKey = strings.TrimSpace(conversationKey)
|
|
if conversationKey == "" {
|
|
return "", nil
|
|
}
|
|
userID := ""
|
|
gatewayUserID := ""
|
|
if user != nil {
|
|
userID = strings.TrimSpace(user.ID)
|
|
gatewayUserID = strings.TrimSpace(user.GatewayUserID)
|
|
}
|
|
if userID == "" {
|
|
userID = "anonymous"
|
|
}
|
|
metadataJSON, _ := json.Marshal(emptyObjectIfNil(metadata))
|
|
var conversationID string
|
|
err := s.pool.QueryRow(ctx, `
|
|
INSERT INTO gateway_conversations (user_id, gateway_user_id, conversation_key, metadata)
|
|
VALUES ($1, NULLIF($2, '')::uuid, $3, $4::jsonb)
|
|
ON CONFLICT (user_id, conversation_key) DO UPDATE
|
|
SET gateway_user_id = COALESCE(gateway_conversations.gateway_user_id, EXCLUDED.gateway_user_id),
|
|
metadata = gateway_conversations.metadata || EXCLUDED.metadata,
|
|
updated_at = now()
|
|
RETURNING id::text`, userID, gatewayUserID, conversationKey, string(metadataJSON)).Scan(&conversationID)
|
|
return conversationID, err
|
|
}
|
|
|
|
func (s *Store) UpsertConversationMessages(ctx context.Context, conversationID string, messages []ConversationMessageInput) ([]TaskMessageRefInput, int, error) {
|
|
if strings.TrimSpace(conversationID) == "" || len(messages) == 0 {
|
|
return nil, 0, nil
|
|
}
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
refs := make([]TaskMessageRefInput, 0, len(messages))
|
|
newCount := 0
|
|
for index, message := range messages {
|
|
snapshotJSON, _ := json.Marshal(emptyObjectIfNil(message.Snapshot))
|
|
var messageID string
|
|
var inserted bool
|
|
if err := tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_conversation_messages (
|
|
conversation_id, message_hash, role, message_snapshot, asset_sha256s
|
|
)
|
|
VALUES ($1::uuid, $2, NULLIF($3, ''), $4::jsonb, $5)
|
|
ON CONFLICT (conversation_id, message_hash) DO UPDATE
|
|
SET updated_at = gateway_conversation_messages.updated_at
|
|
RETURNING id::text, (xmax = 0) AS inserted`,
|
|
conversationID,
|
|
message.Hash,
|
|
message.Role,
|
|
string(snapshotJSON),
|
|
message.AssetSHA256s,
|
|
).Scan(&messageID, &inserted); err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if inserted {
|
|
newCount++
|
|
}
|
|
refs = append(refs, TaskMessageRefInput{MessageID: messageID, Position: index})
|
|
}
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return refs, newCount, nil
|
|
}
|
|
|
|
func (s *Store) ListTaskConversationMessages(ctx context.Context, taskID string) ([]ConversationMessageRef, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT refs.message_id::text, refs.position, messages.message_snapshot
|
|
FROM gateway_task_message_refs refs
|
|
JOIN gateway_conversation_messages messages ON messages.id = refs.message_id
|
|
WHERE refs.task_id = $1::uuid
|
|
ORDER BY refs.position ASC`, taskID)
|
|
if err != nil {
|
|
if IsUndefinedDatabaseObject(err) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]ConversationMessageRef, 0)
|
|
for rows.Next() {
|
|
var item ConversationMessageRef
|
|
var snapshot []byte
|
|
if err := rows.Scan(&item.MessageID, &item.Position, &snapshot); err != nil {
|
|
return nil, err
|
|
}
|
|
item.Message = decodeObject(snapshot)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func insertTaskMessageRefs(ctx context.Context, tx pgx.Tx, taskID string, refs []TaskMessageRefInput) error {
|
|
if len(refs) == 0 {
|
|
return nil
|
|
}
|
|
for _, ref := range refs {
|
|
if strings.TrimSpace(ref.MessageID) == "" {
|
|
continue
|
|
}
|
|
if _, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_task_message_refs (task_id, message_id, position)
|
|
VALUES ($1::uuid, $2::uuid, $3)
|
|
ON CONFLICT (task_id, position) DO UPDATE
|
|
SET message_id = EXCLUDED.message_id`,
|
|
taskID,
|
|
ref.MessageID,
|
|
ref.Position,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|