easyai-ai-gateway/apps/api/internal/store/postgres.go
wangbo 6323e70e49 Initial project scaffold
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-09 14:36:35 +08:00

509 lines
16 KiB
Go

package store
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)
// Store holds the pgx connection pool that every query method in this
// package runs against. Build one with Connect and shut it down with Close.
type Store struct{ pool *pgxpool.Pool }
// Connect creates a connection pool for databaseURL and pings it once before
// handing it out.
//
// If pool creation fails, that error is returned. If the ping fails, the pool
// is closed and the ping error is returned, so a caller never receives a Store
// whose pool could not reach the database.
func Connect(ctx context.Context, databaseURL string) (*Store, error) {
	var (
		pool *pgxpool.Pool
		err  error
	)
	if pool, err = pgxpool.New(ctx, databaseURL); err != nil {
		return nil, err
	}
	if err = pool.Ping(ctx); err != nil {
		// Don't leak the freshly created pool when it is unusable.
		pool.Close()
		return nil, err
	}
	store := &Store{pool: pool}
	return store, nil
}
// Close shuts down the underlying connection pool.
func (s *Store) Close() {
	pool := s.pool
	pool.Close()
}
// Ping checks database connectivity through the pool and returns the ping
// error, or nil when the database answered.
func (s *Store) Ping(ctx context.Context) error {
	err := s.pool.Ping(ctx)
	return err
}
// Platform is one row of the integration_platforms table, as returned by
// ListPlatforms and CreatePlatform.
//
// Stored credentials are not exposed through this type: neither query selects
// the credentials column.
type Platform struct {
ID string `json:"id"`
Provider string `json:"provider"`
PlatformKey string `json:"platformKey"`
Name string `json:"name"`
// BaseURL is "" when the column is NULL (the queries COALESCE it).
BaseURL string `json:"baseUrl,omitempty"`
AuthType string `json:"authType"`
Status string `json:"status"`
Priority int `json:"priority"`
DefaultPricingMode string `json:"defaultPricingMode"`
// DefaultDiscountFactor is read with a ::float8 cast.
DefaultDiscountFactor float64 `json:"defaultDiscountFactor"`
// Config is the config column decoded via decodeObject. It is nil when the
// column is empty or is not a valid JSON object.
Config map[string]any `json:"config,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// CreatePlatformInput carries the fields accepted by CreatePlatform.
//
// CreatePlatform treats zero values as "unset":
//   - an empty PlatformKey is generated by the database;
//   - an empty DefaultPricingMode becomes "inherit_discount";
//   - a DefaultDiscountFactor of 0 becomes 1;
//   - a Priority of 0 becomes 100.
type CreatePlatformInput struct {
Provider string `json:"provider"`
// PlatformKey, when empty, becomes "platform_" followed by a dash-less UUID.
PlatformKey string `json:"platformKey"`
Name string `json:"name"`
BaseURL string `json:"baseUrl"`
AuthType string `json:"authType"`
// Credentials is stored as JSON in the credentials column. It is never read
// back by the queries in this file.
Credentials map[string]any `json:"credentials"`
// Config is stored as JSON in the config column.
Config map[string]any `json:"config"`
DefaultPricingMode string `json:"defaultPricingMode"`
DefaultDiscountFactor float64 `json:"defaultDiscountFactor"`
Priority int `json:"priority"`
}
// PlatformModel is one row of platform_models joined with its owning
// integration_platforms row, as returned by ListModels.
type PlatformModel struct {
ID string `json:"id"`
PlatformID string `json:"platformId"`
// BaseModelID is "" when platform_models.base_model_id is NULL.
BaseModelID string `json:"baseModelId,omitempty"`
// Provider and PlatformName come from the joined integration_platforms row.
Provider string `json:"provider,omitempty"`
PlatformName string `json:"platformName,omitempty"`
ModelName string `json:"modelName"`
// ModelAlias is "" when the column is NULL.
ModelAlias string `json:"modelAlias,omitempty"`
ModelType string `json:"modelType"`
DisplayName string `json:"displayName"`
// CapabilityOverride and Capabilities are decoded from separate JSON columns.
// NOTE(review): which one is the effective value is not decided in this file.
CapabilityOverride map[string]any `json:"capabilityOverride,omitempty"`
Capabilities map[string]any `json:"capabilities,omitempty"`
PricingMode string `json:"pricingMode"`
// DiscountFactor is 0 when the column is NULL (COALESCE in ListModels).
DiscountFactor float64 `json:"discountFactor,omitempty"`
// BillingConfigOverride and BillingConfig are decoded from separate JSON columns.
BillingConfigOverride map[string]any `json:"billingConfigOverride,omitempty"`
BillingConfig map[string]any `json:"billingConfig,omitempty"`
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// CatalogProvider is one row of model_catalog_providers, as returned by
// ListCatalogProviders. The JSON columns are decoded via decodeObject and are
// nil when empty or not a JSON object.
type CatalogProvider struct {
ID string `json:"id"`
ProviderKey string `json:"providerKey"`
DisplayName string `json:"displayName"`
ProviderType string `json:"providerType"`
CapabilitySchema map[string]any `json:"capabilitySchema,omitempty"`
DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy,omitempty"`
Status string `json:"status"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// BaseModel is one row of base_model_catalog, as returned by ListBaseModels.
// The map fields are JSON columns decoded via decodeObject and are nil when
// empty or not a JSON object.
type BaseModel struct {
ID string `json:"id"`
ProviderKey string `json:"providerKey"`
CanonicalModelKey string `json:"canonicalModelKey"`
ProviderModelName string `json:"providerModelName"`
ModelType string `json:"modelType"`
DisplayName string `json:"displayName"`
Capabilities map[string]any `json:"capabilities,omitempty"`
// BaseBillingConfig is decoded from the base_billing_config column.
BaseBillingConfig map[string]any `json:"baseBillingConfig,omitempty"`
DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy,omitempty"`
PricingVersion int `json:"pricingVersion"`
Status string `json:"status"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// PricingRule is one row of model_pricing_rules, as returned by
// ListPricingRules.
type PricingRule struct {
ID string `json:"id"`
ScopeType string `json:"scopeType"`
// ScopeID is "" when scope_id is NULL.
ScopeID string `json:"scopeId,omitempty"`
ResourceType string `json:"resourceType"`
Unit string `json:"unit"`
// BasePrice is read with a ::float8 cast.
BasePrice float64 `json:"basePrice"`
Currency string `json:"currency"`
// BaseWeight and DynamicWeight are JSON columns decoded via decodeObject.
BaseWeight map[string]any `json:"baseWeight,omitempty"`
DynamicWeight map[string]any `json:"dynamicWeight,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// RateLimitWindow is one row of gateway_rate_limit_counters, as returned by
// ListRateLimitWindows. That query only returns rows whose reset_at is at most
// five minutes in the past.
type RateLimitWindow struct {
ScopeType string `json:"scopeType"`
ScopeKey string `json:"scopeKey"`
Metric string `json:"metric"`
WindowStart time.Time `json:"windowStart"`
// LimitValue, UsedValue and ReservedValue are read with ::float8 casts.
LimitValue float64 `json:"limitValue"`
UsedValue float64 `json:"usedValue"`
ReservedValue float64 `json:"reservedValue"`
ResetAt time.Time `json:"resetAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// CreateTaskInput carries the caller-supplied part of a new gateway task.
// CreateTask JSON-encodes Request into the task's request column.
type CreateTaskInput struct {
Kind string `json:"kind"`
Model string `json:"model"`
Request map[string]any `json:"request"`
}
// GatewayTask is one row of gateway_tasks, as returned by CreateTask and
// GetTask.
type GatewayTask struct {
ID string `json:"id"`
Kind string `json:"kind"`
UserID string `json:"userId"`
// TenantID is "" when tenant_id is NULL.
TenantID string `json:"tenantId,omitempty"`
Model string `json:"model"`
// Request and Result are JSON columns decoded via decodeObject.
Request map[string]any `json:"request,omitempty"`
// Status is 'queued' when the row is first inserted by CreateTask.
Status string `json:"status"`
Result map[string]any `json:"result,omitempty"`
// Billings is the billings JSON column decoded as an array via decodeArray.
Billings []any `json:"billings,omitempty"`
// Error is "" when the error column is NULL.
Error string `json:"error,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// ListPlatforms returns every integration platform, ordered by ascending
// priority and, within the same priority, newest first.
//
// The returned slice is non-nil even when there are no rows, so it encodes
// as an empty JSON array. The config column is decoded best-effort via
// decodeObject.
func (s *Store) ListPlatforms(ctx context.Context) ([]Platform, error) {
	const query = `
SELECT id::text, provider, platform_key, name, COALESCE(base_url, ''), auth_type, status, priority,
default_pricing_mode, default_discount_factor::float8, config, created_at, updated_at
FROM integration_platforms
ORDER BY priority ASC, created_at DESC`

	rows, err := s.pool.Query(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	result := []Platform{}
	for rows.Next() {
		var (
			p         Platform
			rawConfig []byte
		)
		scanErr := rows.Scan(
			&p.ID, &p.Provider, &p.PlatformKey, &p.Name, &p.BaseURL,
			&p.AuthType, &p.Status, &p.Priority,
			&p.DefaultPricingMode, &p.DefaultDiscountFactor,
			&rawConfig, &p.CreatedAt, &p.UpdatedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		p.Config = decodeObject(rawConfig)
		result = append(result, p)
	}
	return result, rows.Err()
}
// CreatePlatform inserts a new integration platform and returns the stored row.
//
// Zero values in input are replaced before the insert:
//   - DefaultPricingMode "" becomes "inherit_discount";
//   - DefaultDiscountFactor 0 becomes 1;
//   - Priority 0 becomes 100;
//   - an empty PlatformKey is generated by the database as "platform_"
//     followed by a dash-less random UUID.
//
// Credentials and Config are stored as JSON. Credentials are not part of the
// returned Platform.
//
// NOTE(review): because 0 means "unset", a caller cannot request a discount
// factor or priority of exactly 0 through this input type.
//
// It returns an error if Credentials or Config cannot be JSON-encoded, or if
// the insert fails.
func (s *Store) CreatePlatform(ctx context.Context, input CreatePlatformInput) (Platform, error) {
	// These encode errors used to be discarded, which passed a nil payload to
	// the INSERT instead of reporting the problem (json.Marshal rejects, for
	// example, NaN floats and channel values).
	credentials, err := json.Marshal(input.Credentials)
	if err != nil {
		return Platform{}, fmt.Errorf("encoding platform credentials: %w", err)
	}
	config, err := json.Marshal(input.Config)
	if err != nil {
		return Platform{}, fmt.Errorf("encoding platform config: %w", err)
	}

	// input is a value copy, so applying defaults here does not affect the caller.
	if input.DefaultPricingMode == "" {
		input.DefaultPricingMode = "inherit_discount"
	}
	if input.DefaultDiscountFactor == 0 {
		input.DefaultDiscountFactor = 1
	}
	if input.Priority == 0 {
		input.Priority = 100
	}

	var platform Platform
	var configBytes []byte
	err = s.pool.QueryRow(ctx, `
INSERT INTO integration_platforms (provider, platform_key, name, base_url, auth_type, credentials, config, default_pricing_mode, default_discount_factor, priority)
VALUES ($1, COALESCE(NULLIF($2, ''), 'platform_' || replace(gen_random_uuid()::text, '-', '')), $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id::text, provider, platform_key, name, COALESCE(base_url, ''), auth_type, status, priority,
default_pricing_mode, default_discount_factor::float8, config, created_at, updated_at`,
		input.Provider, input.PlatformKey, input.Name, input.BaseURL, input.AuthType, credentials, config, input.DefaultPricingMode, input.DefaultDiscountFactor, input.Priority,
	).Scan(
		&platform.ID,
		&platform.Provider,
		&platform.PlatformKey,
		&platform.Name,
		&platform.BaseURL,
		&platform.AuthType,
		&platform.Status,
		&platform.Priority,
		&platform.DefaultPricingMode,
		&platform.DefaultDiscountFactor,
		&configBytes,
		&platform.CreatedAt,
		&platform.UpdatedAt,
	)
	if err != nil {
		return Platform{}, err
	}
	platform.Config = decodeObject(configBytes)
	return platform, nil
}
// ListModels returns every platform model joined with its owning platform,
// ordered by model type and then model name.
//
// Provider and PlatformName come from the joined integration_platforms row.
// A NULL base_model_id or model_alias becomes "", and a NULL discount_factor
// becomes 0. JSON columns are decoded best-effort via decodeObject. The
// returned slice is non-nil even when empty.
func (s *Store) ListModels(ctx context.Context) ([]PlatformModel, error) {
	rows, err := s.pool.Query(ctx, `
SELECT m.id::text, m.platform_id::text, COALESCE(m.base_model_id::text, ''), p.provider, p.name,
m.model_name, COALESCE(m.model_alias, ''), m.model_type, m.display_name,
m.capability_override, m.capabilities, m.pricing_mode, COALESCE(m.discount_factor, 0)::float8,
m.billing_config_override, m.billing_config, m.enabled, m.created_at, m.updated_at
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
ORDER BY m.model_type ASC, m.model_name ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := []PlatformModel{}
	for rows.Next() {
		var m PlatformModel
		var rawCapOverride, rawCaps, rawBillingOverride, rawBilling []byte

		// Destination order must match the SELECT column order above.
		dest := []any{
			&m.ID, &m.PlatformID, &m.BaseModelID, &m.Provider, &m.PlatformName,
			&m.ModelName, &m.ModelAlias, &m.ModelType, &m.DisplayName,
			&rawCapOverride, &rawCaps, &m.PricingMode, &m.DiscountFactor,
			&rawBillingOverride, &rawBilling, &m.Enabled, &m.CreatedAt, &m.UpdatedAt,
		}
		if scanErr := rows.Scan(dest...); scanErr != nil {
			return nil, scanErr
		}

		m.CapabilityOverride = decodeObject(rawCapOverride)
		m.Capabilities = decodeObject(rawCaps)
		m.BillingConfigOverride = decodeObject(rawBillingOverride)
		m.BillingConfig = decodeObject(rawBilling)
		out = append(out, m)
	}
	return out, rows.Err()
}
// ListCatalogProviders returns all rows of model_catalog_providers ordered by
// provider key. The capability schema and default rate-limit policy are
// decoded best-effort via decodeObject. The returned slice is non-nil even
// when empty.
func (s *Store) ListCatalogProviders(ctx context.Context) ([]CatalogProvider, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id::text, provider_key, display_name, provider_type, capability_schema,
default_rate_limit_policy, status, created_at, updated_at
FROM model_catalog_providers
ORDER BY provider_key ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// scanProvider reads the current row and decodes its JSON columns.
	scanProvider := func() (CatalogProvider, error) {
		var p CatalogProvider
		var schemaRaw, policyRaw []byte
		if e := rows.Scan(
			&p.ID, &p.ProviderKey, &p.DisplayName, &p.ProviderType,
			&schemaRaw, &policyRaw,
			&p.Status, &p.CreatedAt, &p.UpdatedAt,
		); e != nil {
			return CatalogProvider{}, e
		}
		p.CapabilitySchema = decodeObject(schemaRaw)
		p.DefaultRateLimitPolicy = decodeObject(policyRaw)
		return p, nil
	}

	providers := make([]CatalogProvider, 0)
	for rows.Next() {
		p, scanErr := scanProvider()
		if scanErr != nil {
			return nil, scanErr
		}
		providers = append(providers, p)
	}
	return providers, rows.Err()
}
// ListBaseModels returns every base_model_catalog row, ordered by provider
// key, model type and canonical model key. JSON columns are decoded
// best-effort via decodeObject. The returned slice is non-nil even when empty.
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy, pricing_version,
status, created_at, updated_at
FROM base_model_catalog
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	catalog := []BaseModel{}
	for rows.Next() {
		var (
			bm                          BaseModel
			capsRaw, billingRaw, rlpRaw []byte
		)
		err := rows.Scan(
			&bm.ID, &bm.ProviderKey, &bm.CanonicalModelKey, &bm.ProviderModelName,
			&bm.ModelType, &bm.DisplayName,
			&capsRaw, &billingRaw, &rlpRaw,
			&bm.PricingVersion, &bm.Status, &bm.CreatedAt, &bm.UpdatedAt,
		)
		if err != nil {
			return nil, err
		}
		bm.Capabilities, bm.BaseBillingConfig, bm.DefaultRateLimitPolicy =
			decodeObject(capsRaw), decodeObject(billingRaw), decodeObject(rlpRaw)
		catalog = append(catalog, bm)
	}
	return catalog, rows.Err()
}
// ListPricingRules returns every model_pricing_rules row, ordered by scope
// type and resource type, newest first within each group. A NULL scope_id
// becomes "". The weight columns are decoded best-effort via decodeObject.
// The returned slice is non-nil even when empty.
func (s *Store) ListPricingRules(ctx context.Context) ([]PricingRule, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id::text, scope_type, COALESCE(scope_id::text, ''), resource_type, unit,
base_price::float8, currency, base_weight, dynamic_weight, created_at, updated_at
FROM model_pricing_rules
ORDER BY scope_type ASC, resource_type ASC, created_at DESC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rules := make([]PricingRule, 0)
	for rows.Next() {
		var rule PricingRule
		var baseRaw, dynamicRaw []byte
		if scanErr := rows.Scan(
			&rule.ID, &rule.ScopeType, &rule.ScopeID, &rule.ResourceType, &rule.Unit,
			&rule.BasePrice, &rule.Currency,
			&baseRaw, &dynamicRaw,
			&rule.CreatedAt, &rule.UpdatedAt,
		); scanErr != nil {
			return nil, scanErr
		}
		rule.BaseWeight = decodeObject(baseRaw)
		rule.DynamicWeight = decodeObject(dynamicRaw)
		rules = append(rules, rule)
	}
	return rules, rows.Err()
}
// ListRateLimitWindows returns rate-limit counter rows whose reset_at is no
// more than five minutes in the past. Rows are ordered newest window first,
// then by scope type, scope key and metric. The returned slice is non-nil
// even when empty.
func (s *Store) ListRateLimitWindows(ctx context.Context) ([]RateLimitWindow, error) {
	rows, err := s.pool.Query(ctx, `
SELECT scope_type, scope_key, metric, window_start, limit_value::float8, used_value::float8,
reserved_value::float8, reset_at, updated_at
FROM gateway_rate_limit_counters
WHERE reset_at >= now() - interval '5 minutes'
ORDER BY window_start DESC, scope_type ASC, scope_key ASC, metric ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	windows := []RateLimitWindow{}
	for rows.Next() {
		// Scan straight into a freshly appended slot. On a scan error the
		// whole slice is dropped, so a half-filled element never escapes.
		windows = append(windows, RateLimitWindow{})
		w := &windows[len(windows)-1]
		if scanErr := rows.Scan(
			&w.ScopeType, &w.ScopeKey, &w.Metric, &w.WindowStart,
			&w.LimitValue, &w.UsedValue, &w.ReservedValue,
			&w.ResetAt, &w.UpdatedAt,
		); scanErr != nil {
			return nil, scanErr
		}
	}
	return windows, rows.Err()
}
// CreateTask inserts a new gateway task with status 'queued' on behalf of
// user and returns the stored row.
//
// The task is attributed to user.ID. An empty user.TenantID is stored as NULL
// and read back as "". input.Request is JSON-encoded into the request column.
// JSON columns in the returned row are decoded best-effort.
//
// It returns an error, without touching the database, if user is nil or
// input.Request cannot be JSON-encoded. Database errors are returned as-is.
func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *auth.User) (GatewayTask, error) {
	// A nil user previously panicked on the field access below.
	if user == nil {
		return GatewayTask{}, errors.New("creating task: nil user")
	}
	// The encode error was previously ignored, which passed a nil request
	// payload to the INSERT.
	requestBody, err := json.Marshal(input.Request)
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encoding task request: %w", err)
	}

	var task GatewayTask
	var requestBytes []byte
	var resultBytes []byte
	var billingsBytes []byte
	err = s.pool.QueryRow(ctx, `
INSERT INTO gateway_tasks (kind, user_id, tenant_id, model, request, status)
VALUES ($1, $2, NULLIF($3, ''), $4, $5, 'queued')
RETURNING id::text, kind, user_id, COALESCE(tenant_id, ''), model, request, status, result, billings, COALESCE(error, ''), created_at, updated_at`,
		input.Kind, user.ID, user.TenantID, input.Model, requestBody,
	).Scan(&task.ID, &task.Kind, &task.UserID, &task.TenantID, &task.Model, &requestBytes, &task.Status, &resultBytes, &billingsBytes, &task.Error, &task.CreatedAt, &task.UpdatedAt)
	if err != nil {
		return GatewayTask{}, err
	}
	task.Request = decodeObject(requestBytes)
	task.Result = decodeObject(resultBytes)
	task.Billings = decodeArray(billingsBytes)
	return task, nil
}
// GetTask loads the gateway task whose id equals taskID.
//
// When no row matches, the Scan error is returned unchanged; callers can
// check it with IsNotFound. The request, result and billings JSON columns are
// decoded best-effort and become nil when empty or malformed.
func (s *Store) GetTask(ctx context.Context, taskID string) (GatewayTask, error) {
	var (
		t                          GatewayTask
		reqRaw, resRaw, billingRaw []byte
	)
	row := s.pool.QueryRow(ctx, `
SELECT id::text, kind, user_id, COALESCE(tenant_id, ''), model, request, status, result, billings, COALESCE(error, ''), created_at, updated_at
FROM gateway_tasks
WHERE id=$1`, taskID,
	)
	if err := row.Scan(
		&t.ID, &t.Kind, &t.UserID, &t.TenantID, &t.Model,
		&reqRaw, &t.Status, &resRaw, &billingRaw,
		&t.Error, &t.CreatedAt, &t.UpdatedAt,
	); err != nil {
		return GatewayTask{}, err
	}
	t.Request = decodeObject(reqRaw)
	t.Result = decodeObject(resRaw)
	t.Billings = decodeArray(billingRaw)
	return t, nil
}
// IsNotFound reports whether err means a single-row query matched no rows
// (pgx.ErrNoRows).
//
// It uses errors.Is rather than ==, so the check still succeeds when an
// intermediate layer has wrapped the error with fmt.Errorf("...: %w", err).
func IsNotFound(err error) bool {
	return errors.Is(err, pgx.ErrNoRows)
}
// decodeObject decodes raw as a JSON object, on a best-effort basis.
//
// It returns nil for empty input, for JSON null, and for input that is not
// valid JSON or is not a JSON object. The decode error is intentionally
// discarded so that one malformed column never fails a whole listing.
func decodeObject(raw []byte) map[string]any {
	var obj map[string]any
	if len(raw) > 0 && json.Unmarshal(raw, &obj) == nil {
		return obj
	}
	return nil
}
// decodeArray decodes raw as a JSON array, on a best-effort basis.
//
// It returns nil for empty input, for JSON null, and for input that is not
// valid JSON or is not a JSON array. The decode error is intentionally
// discarded.
func decodeArray(raw []byte) []any {
	if len(raw) == 0 {
		return nil
	}
	var items []any
	if decodeErr := json.Unmarshal(raw, &items); decodeErr != nil {
		return nil
	}
	return items
}