easyai-ai-gateway/apps/api/internal/store/conversations.go

150 lines
4.2 KiB
Go

package store
import (
"context"
"encoding/json"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/jackc/pgx/v5"
)
// ConversationMessageInput describes one message handed to
// UpsertConversationMessages for storage in gateway_conversation_messages.
type ConversationMessageInput struct {
	// Hash is written to message_hash; together with the conversation ID it
	// is the conflict key that de-duplicates messages.
	Hash string
	// Role is written to the role column; an empty string is stored as NULL.
	Role string
	// Snapshot is JSON-encoded into the message_snapshot column.
	Snapshot map[string]any
	// AssetSHA256s is written to the asset_sha256s column as-is.
	AssetSHA256s []string
}
// TaskMessageRefInput links a stored conversation message to a slot in a
// task's message list. UpsertConversationMessages produces these values and
// insertTaskMessageRefs writes them to gateway_task_message_refs.
type TaskMessageRefInput struct {
	// MessageID is the stored message's ID in text form; it is cast to uuid
	// when written.
	MessageID string
	// Position is the index of the message within the task's message list.
	Position int
}
// ConversationMessageRef is one entry of a task's message list as returned by
// ListTaskConversationMessages, serialised with camelCase JSON keys.
type ConversationMessageRef struct {
	// MessageID is the referenced message's ID in text form.
	MessageID string `json:"messageId"`
	// Position is the slot of the message within the task's message list.
	Position int `json:"position"`
	// Message is the decoded message_snapshot of the referenced message.
	Message map[string]any `json:"message"`
}
// EnsureConversation upserts the gateway_conversations row identified by
// (user ID, conversationKey) and returns its ID in text form.
//
// conversationKey is trimmed first; a blank key is treated as "no
// conversation" and yields ("", nil) without touching the database. A nil
// user, or a user whose trimmed ID is empty, is recorded under the literal
// user ID "anonymous". An empty gateway user ID is stored as NULL.
//
// When the row already exists, a gateway_user_id that is already set is kept
// (the incoming one only fills a NULL), the incoming metadata is merged into
// the stored metadata with the jsonb || operator, and updated_at is refreshed.
//
// An error is returned if metadata cannot be encoded as JSON or if the
// statement fails.
func (s *Store) EnsureConversation(ctx context.Context, user *auth.User, conversationKey string, metadata map[string]any) (string, error) {
	conversationKey = strings.TrimSpace(conversationKey)
	if conversationKey == "" {
		return "", nil
	}

	userID := ""
	gatewayUserID := ""
	if user != nil {
		userID = strings.TrimSpace(user.ID)
		gatewayUserID = strings.TrimSpace(user.GatewayUserID)
	}
	if userID == "" {
		userID = "anonymous"
	}

	// Surface encoding failures here: an empty payload would otherwise reach
	// the $4::jsonb cast and come back as an unrelated database syntax error.
	metadataJSON, err := json.Marshal(emptyObjectIfNil(metadata))
	if err != nil {
		return "", err
	}

	var conversationID string
	err = s.pool.QueryRow(ctx, `
INSERT INTO gateway_conversations (user_id, gateway_user_id, conversation_key, metadata)
VALUES ($1, NULLIF($2, '')::uuid, $3, $4::jsonb)
ON CONFLICT (user_id, conversation_key) DO UPDATE
SET gateway_user_id = COALESCE(gateway_conversations.gateway_user_id, EXCLUDED.gateway_user_id),
metadata = gateway_conversations.metadata || EXCLUDED.metadata,
updated_at = now()
RETURNING id::text`, userID, gatewayUserID, conversationKey, string(metadataJSON)).Scan(&conversationID)
	return conversationID, err
}
// UpsertConversationMessages stores messages under conversationID inside a
// single transaction and returns one TaskMessageRefInput per input message,
// in input order, with Position set to the message's index in the slice. The
// second result is the number of messages that were newly inserted rather
// than already present.
//
// A blank conversationID or an empty messages slice is a no-op that returns
// (nil, 0, nil). Messages are de-duplicated on (conversation_id,
// message_hash); an existing row is left unchanged and its ID is reused.
//
// An error is returned if the transaction cannot be started or committed, if
// a snapshot cannot be encoded as JSON, or if any statement fails; in every
// error case nothing is committed.
func (s *Store) UpsertConversationMessages(ctx context.Context, conversationID string, messages []ConversationMessageInput) ([]TaskMessageRefInput, int, error) {
	if strings.TrimSpace(conversationID) == "" || len(messages) == 0 {
		return nil, 0, nil
	}

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return nil, 0, err
	}
	// Covers every early return below; after a successful Commit the
	// transaction is already closed and the rollback result is ignored.
	defer tx.Rollback(ctx)

	refs := make([]TaskMessageRefInput, 0, len(messages))
	newCount := 0
	for index, message := range messages {
		// Fail on un-encodable snapshots instead of sending an empty payload
		// to the $4::jsonb cast.
		snapshotJSON, err := json.Marshal(emptyObjectIfNil(message.Snapshot))
		if err != nil {
			return nil, 0, err
		}

		var messageID string
		var inserted bool
		// The self-assigning DO UPDATE makes RETURNING yield the existing row
		// on conflict; (xmax = 0) is used to tell a fresh insert from that case.
		if err := tx.QueryRow(ctx, `
INSERT INTO gateway_conversation_messages (
conversation_id, message_hash, role, message_snapshot, asset_sha256s
)
VALUES ($1::uuid, $2, NULLIF($3, ''), $4::jsonb, $5)
ON CONFLICT (conversation_id, message_hash) DO UPDATE
SET updated_at = gateway_conversation_messages.updated_at
RETURNING id::text, (xmax = 0) AS inserted`,
			conversationID,
			message.Hash,
			message.Role,
			string(snapshotJSON),
			message.AssetSHA256s,
		).Scan(&messageID, &inserted); err != nil {
			return nil, 0, err
		}
		if inserted {
			newCount++
		}
		refs = append(refs, TaskMessageRefInput{MessageID: messageID, Position: index})
	}

	if err := tx.Commit(ctx); err != nil {
		return nil, 0, err
	}
	return refs, newCount, nil
}
// ListTaskConversationMessages returns the messages referenced by the task
// taskID, ordered by ascending position, with each stored snapshot decoded
// into the Message field.
//
// If the query fails with an error that IsUndefinedDatabaseObject recognises,
// the result is (nil, nil) rather than an error. A task without references
// yields an empty, non-nil slice.
func (s *Store) ListTaskConversationMessages(ctx context.Context, taskID string) ([]ConversationMessageRef, error) {
	const listQuery = `
SELECT refs.message_id::text, refs.position, messages.message_snapshot
FROM gateway_task_message_refs refs
JOIN gateway_conversation_messages messages ON messages.id = refs.message_id
WHERE refs.task_id = $1::uuid
ORDER BY refs.position ASC`

	rows, queryErr := s.pool.Query(ctx, listQuery, taskID)
	switch {
	case queryErr == nil:
		// Fall through to reading the rows.
	case IsUndefinedDatabaseObject(queryErr):
		return nil, nil
	default:
		return nil, queryErr
	}
	defer rows.Close()

	result := []ConversationMessageRef{}
	for rows.Next() {
		var (
			ref ConversationMessageRef
			raw []byte
		)
		if scanErr := rows.Scan(&ref.MessageID, &ref.Position, &raw); scanErr != nil {
			return nil, scanErr
		}
		ref.Message = decodeObject(raw)
		result = append(result, ref)
	}
	// Rows read before an iteration error are still returned with that error.
	return result, rows.Err()
}
// insertTaskMessageRefs writes refs for taskID into gateway_task_message_refs
// using the caller's transaction tx. Entries whose MessageID is blank after
// trimming are skipped. If a row already exists for (task_id, position), its
// message_id is overwritten. The first failing statement aborts the loop and
// its error is returned; committing or rolling back is left to the caller.
func insertTaskMessageRefs(ctx context.Context, tx pgx.Tx, taskID string, refs []TaskMessageRefInput) error {
	const upsertRef = `
INSERT INTO gateway_task_message_refs (task_id, message_id, position)
VALUES ($1::uuid, $2::uuid, $3)
ON CONFLICT (task_id, position) DO UPDATE
SET message_id = EXCLUDED.message_id`

	for i := range refs {
		entry := refs[i]
		if strings.TrimSpace(entry.MessageID) != "" {
			_, execErr := tx.Exec(ctx, upsertRef, taskID, entry.MessageID, entry.Position)
			if execErr != nil {
				return execErr
			}
		}
	}
	return nil
}