package store

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

// Store wraps a pgx connection pool and exposes the gateway's database queries.
type Store struct {
	pool *pgxpool.Pool
}

// Connect opens a connection pool for databaseURL and verifies it with a ping.
// If the ping fails the pool is closed before the error is returned.
func Connect(ctx context.Context, databaseURL string) (*Store, error) {
	pool, err := pgxpool.New(ctx, databaseURL)
	if err != nil {
		return nil, err
	}
	if err := pool.Ping(ctx); err != nil {
		pool.Close()
		return nil, err
	}
	return &Store{pool: pool}, nil
}

// Close closes the underlying connection pool.
func (s *Store) Close() {
	s.pool.Close()
}

// Ping checks that the database is reachable through the pool.
func (s *Store) Ping(ctx context.Context) error {
	return s.pool.Ping(ctx)
}

// Platform is an integration_platforms row as returned by ListPlatforms and
// CreatePlatform. The credentials column is not selected into this type.
type Platform struct {
	ID                    string         `json:"id"`
	Provider              string         `json:"provider"`
	PlatformKey           string         `json:"platformKey"`
	Name                  string         `json:"name"`
	BaseURL               string         `json:"baseUrl,omitempty"`
	AuthType              string         `json:"authType"`
	Status                string         `json:"status"`
	Priority              int            `json:"priority"`
	DefaultPricingMode    string         `json:"defaultPricingMode"`
	DefaultDiscountFactor float64        `json:"defaultDiscountFactor"`
	Config                map[string]any `json:"config,omitempty"`
	CreatedAt             time.Time      `json:"createdAt"`
	UpdatedAt             time.Time      `json:"updatedAt"`
}

// CreatePlatformInput carries the fields accepted by CreatePlatform. Zero
// values of DefaultPricingMode, DefaultDiscountFactor and Priority are replaced
// with defaults by CreatePlatform.
type CreatePlatformInput struct {
	Provider              string         `json:"provider"`
	PlatformKey           string         `json:"platformKey"`
	Name                  string         `json:"name"`
	BaseURL               string         `json:"baseUrl"`
	AuthType              string         `json:"authType"`
	Credentials           map[string]any `json:"credentials"`
	Config                map[string]any `json:"config"`
	DefaultPricingMode    string         `json:"defaultPricingMode"`
	DefaultDiscountFactor float64        `json:"defaultDiscountFactor"`
	Priority              int            `json:"priority"`
}

// PlatformModel is a platform_models row joined with its owning platform's
// provider and name, as returned by ListModels.
type PlatformModel struct {
	ID                    string         `json:"id"`
	PlatformID            string         `json:"platformId"`
	BaseModelID           string         `json:"baseModelId,omitempty"`
	Provider              string         `json:"provider,omitempty"`
	PlatformName          string         `json:"platformName,omitempty"`
	ModelName             string         `json:"modelName"`
	ModelAlias            string         `json:"modelAlias,omitempty"`
	ModelType             string         `json:"modelType"`
	DisplayName           string         `json:"displayName"`
	CapabilityOverride    map[string]any `json:"capabilityOverride,omitempty"`
	Capabilities          map[string]any `json:"capabilities,omitempty"`
	PricingMode           string         `json:"pricingMode"`
	DiscountFactor        float64        `json:"discountFactor,omitempty"`
	BillingConfigOverride map[string]any `json:"billingConfigOverride,omitempty"`
	BillingConfig         map[string]any `json:"billingConfig,omitempty"`
	Enabled               bool           `json:"enabled"`
	CreatedAt             time.Time      `json:"createdAt"`
	UpdatedAt             time.Time      `json:"updatedAt"`
}

// CatalogProvider is a model_catalog_providers row.
type CatalogProvider struct {
	ID                     string         `json:"id"`
	ProviderKey            string         `json:"providerKey"`
	DisplayName            string         `json:"displayName"`
	ProviderType           string         `json:"providerType"`
	CapabilitySchema       map[string]any `json:"capabilitySchema,omitempty"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy,omitempty"`
	Status                 string         `json:"status"`
	CreatedAt              time.Time      `json:"createdAt"`
	UpdatedAt              time.Time      `json:"updatedAt"`
}

// BaseModel is a base_model_catalog row.
type BaseModel struct {
	ID                     string         `json:"id"`
	ProviderKey            string         `json:"providerKey"`
	CanonicalModelKey      string         `json:"canonicalModelKey"`
	ProviderModelName      string         `json:"providerModelName"`
	ModelType              string         `json:"modelType"`
	DisplayName            string         `json:"displayName"`
	Capabilities           map[string]any `json:"capabilities,omitempty"`
	BaseBillingConfig      map[string]any `json:"baseBillingConfig,omitempty"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy,omitempty"`
	PricingVersion         int            `json:"pricingVersion"`
	Status                 string         `json:"status"`
	CreatedAt              time.Time      `json:"createdAt"`
	UpdatedAt              time.Time      `json:"updatedAt"`
}

// PricingRule is a model_pricing_rules row.
type PricingRule struct {
	ID            string         `json:"id"`
	ScopeType     string         `json:"scopeType"`
	ScopeID       string         `json:"scopeId,omitempty"`
	ResourceType  string         `json:"resourceType"`
	Unit          string         `json:"unit"`
	BasePrice     float64        `json:"basePrice"`
	Currency      string         `json:"currency"`
	BaseWeight    map[string]any `json:"baseWeight,omitempty"`
	DynamicWeight map[string]any `json:"dynamicWeight,omitempty"`
	CreatedAt     time.Time      `json:"createdAt"`
	UpdatedAt     time.Time      `json:"updatedAt"`
}

// RateLimitWindow is a gateway_rate_limit_counters row.
type RateLimitWindow struct {
	ScopeType     string    `json:"scopeType"`
	ScopeKey      string    `json:"scopeKey"`
	Metric        string    `json:"metric"`
	WindowStart   time.Time `json:"windowStart"`
	LimitValue    float64   `json:"limitValue"`
	UsedValue     float64   `json:"usedValue"`
	ReservedValue float64   `json:"reservedValue"`
	ResetAt       time.Time `json:"resetAt"`
	UpdatedAt     time.Time `json:"updatedAt"`
}

// CreateTaskInput carries the caller-supplied fields for CreateTask.
type CreateTaskInput struct {
	Kind    string         `json:"kind"`
	Model   string         `json:"model"`
	Request map[string]any `json:"request"`
}

// GatewayTask is a gateway_tasks row as returned by CreateTask and GetTask.
type GatewayTask struct {
	ID        string         `json:"id"`
	Kind      string         `json:"kind"`
	UserID    string         `json:"userId"`
	TenantID  string         `json:"tenantId,omitempty"`
	Model     string         `json:"model"`
	Request   map[string]any `json:"request,omitempty"`
	Status    string         `json:"status"`
	Result    map[string]any `json:"result,omitempty"`
	Billings  []any          `json:"billings,omitempty"`
	Error     string         `json:"error,omitempty"`
	CreatedAt time.Time      `json:"createdAt"`
	UpdatedAt time.Time      `json:"updatedAt"`
}

// platformColumns is the projection shared by ListPlatforms (SELECT) and
// CreatePlatform (RETURNING). Its column order must match scanPlatform.
const platformColumns = `id::text, provider, platform_key, name, COALESCE(base_url, ''), auth_type, status, priority,
		default_pricing_mode, default_discount_factor::float8, config, created_at, updated_at`

// scanPlatform reads one row shaped like platformColumns into a Platform and
// decodes its config JSON. It accepts both pgx.Rows and a QueryRow result.
func scanPlatform(row pgx.Row) (Platform, error) {
	var platform Platform
	var configBytes []byte
	if err := row.Scan(
		&platform.ID,
		&platform.Provider,
		&platform.PlatformKey,
		&platform.Name,
		&platform.BaseURL,
		&platform.AuthType,
		&platform.Status,
		&platform.Priority,
		&platform.DefaultPricingMode,
		&platform.DefaultDiscountFactor,
		&configBytes,
		&platform.CreatedAt,
		&platform.UpdatedAt,
	); err != nil {
		return Platform{}, err
	}
	platform.Config = decodeObject(configBytes)
	return platform, nil
}

// ListPlatforms returns all integration platforms ordered by priority
// (ascending), then newest first. The result is never nil, so it encodes as a
// JSON array even when empty.
func (s *Store) ListPlatforms(ctx context.Context) ([]Platform, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT `+platformColumns+`
		FROM integration_platforms
		ORDER BY priority ASC, created_at DESC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	platforms := make([]Platform, 0)
	for rows.Next() {
		platform, err := scanPlatform(rows)
		if err != nil {
			return nil, err
		}
		platforms = append(platforms, platform)
	}
	return platforms, rows.Err()
}

// CreatePlatform inserts a new integration platform and returns the stored row.
//
// Defaults applied when the input field is zero:
//   - DefaultPricingMode: "inherit_discount"
//   - DefaultDiscountFactor: 1
//   - Priority: 100
//   - PlatformKey: "platform_" followed by a dash-free random UUID (generated in SQL)
//
// An error is returned if Credentials or Config cannot be encoded as JSON.
func (s *Store) CreatePlatform(ctx context.Context, input CreatePlatformInput) (Platform, error) {
	credentials, err := marshalObject(input.Credentials)
	if err != nil {
		return Platform{}, fmt.Errorf("encode platform credentials: %w", err)
	}
	config, err := marshalObject(input.Config)
	if err != nil {
		return Platform{}, fmt.Errorf("encode platform config: %w", err)
	}
	if input.DefaultPricingMode == "" {
		input.DefaultPricingMode = "inherit_discount"
	}
	if input.DefaultDiscountFactor == 0 {
		input.DefaultDiscountFactor = 1
	}
	if input.Priority == 0 {
		input.Priority = 100
	}

	row := s.pool.QueryRow(ctx, `
		INSERT INTO integration_platforms (provider, platform_key, name, base_url, auth_type, credentials, config, default_pricing_mode, default_discount_factor, priority)
		VALUES ($1, COALESCE(NULLIF($2, ''), 'platform_' || replace(gen_random_uuid()::text, '-', '')), $3, $4, $5, $6, $7, $8, $9, $10)
		RETURNING `+platformColumns,
		input.Provider,
		input.PlatformKey,
		input.Name,
		input.BaseURL,
		input.AuthType,
		credentials,
		config,
		input.DefaultPricingMode,
		input.DefaultDiscountFactor,
		input.Priority,
	)
	return scanPlatform(row)
}

// ListModels returns every platform model joined with its platform's provider
// and name, ordered by model type and then model name. A NULL discount_factor
// is reported as 0. The result is never nil.
func (s *Store) ListModels(ctx context.Context) ([]PlatformModel, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT m.id::text, m.platform_id::text, COALESCE(m.base_model_id::text, ''), p.provider, p.name,
			m.model_name, COALESCE(m.model_alias, ''), m.model_type, m.display_name,
			m.capability_override, m.capabilities, m.pricing_mode, COALESCE(m.discount_factor, 0)::float8,
			m.billing_config_override, m.billing_config, m.enabled, m.created_at, m.updated_at
		FROM platform_models m
		JOIN integration_platforms p ON p.id = m.platform_id
		ORDER BY m.model_type ASC, m.model_name ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	models := make([]PlatformModel, 0)
	for rows.Next() {
		var model PlatformModel
		var capabilityOverride, capabilities, billingConfigOverride, billingConfig []byte
		if err := rows.Scan(
			&model.ID,
			&model.PlatformID,
			&model.BaseModelID,
			&model.Provider,
			&model.PlatformName,
			&model.ModelName,
			&model.ModelAlias,
			&model.ModelType,
			&model.DisplayName,
			&capabilityOverride,
			&capabilities,
			&model.PricingMode,
			&model.DiscountFactor,
			&billingConfigOverride,
			&billingConfig,
			&model.Enabled,
			&model.CreatedAt,
			&model.UpdatedAt,
		); err != nil {
			return nil, err
		}
		model.CapabilityOverride = decodeObject(capabilityOverride)
		model.Capabilities = decodeObject(capabilities)
		model.BillingConfigOverride = decodeObject(billingConfigOverride)
		model.BillingConfig = decodeObject(billingConfig)
		models = append(models, model)
	}
	return models, rows.Err()
}

// ListCatalogProviders returns all model catalog providers ordered by
// provider_key. The result is never nil.
func (s *Store) ListCatalogProviders(ctx context.Context) ([]CatalogProvider, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT id::text, provider_key, display_name, provider_type, capability_schema,
			default_rate_limit_policy, status, created_at, updated_at
		FROM model_catalog_providers
		ORDER BY provider_key ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := make([]CatalogProvider, 0)
	for rows.Next() {
		var item CatalogProvider
		var capabilitySchema, rateLimitPolicy []byte
		if err := rows.Scan(
			&item.ID,
			&item.ProviderKey,
			&item.DisplayName,
			&item.ProviderType,
			&capabilitySchema,
			&rateLimitPolicy,
			&item.Status,
			&item.CreatedAt,
			&item.UpdatedAt,
		); err != nil {
			return nil, err
		}
		item.CapabilitySchema = decodeObject(capabilitySchema)
		item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
		items = append(items, item)
	}
	return items, rows.Err()
}

// ListBaseModels returns the base model catalog ordered by provider_key, then
// model_type, then canonical_model_key. The result is never nil.
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
			capabilities, base_billing_config, default_rate_limit_policy, pricing_version, status, created_at, updated_at
		FROM base_model_catalog
		ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := make([]BaseModel, 0)
	for rows.Next() {
		var item BaseModel
		var capabilities, billingConfig, rateLimitPolicy []byte
		if err := rows.Scan(
			&item.ID,
			&item.ProviderKey,
			&item.CanonicalModelKey,
			&item.ProviderModelName,
			&item.ModelType,
			&item.DisplayName,
			&capabilities,
			&billingConfig,
			&rateLimitPolicy,
			&item.PricingVersion,
			&item.Status,
			&item.CreatedAt,
			&item.UpdatedAt,
		); err != nil {
			return nil, err
		}
		item.Capabilities = decodeObject(capabilities)
		item.BaseBillingConfig = decodeObject(billingConfig)
		item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
		items = append(items, item)
	}
	return items, rows.Err()
}

// ListPricingRules returns all pricing rules ordered by scope_type, then
// resource_type, then newest first. A NULL scope_id is reported as "". The
// result is never nil.
func (s *Store) ListPricingRules(ctx context.Context) ([]PricingRule, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT id::text, scope_type, COALESCE(scope_id::text, ''), resource_type, unit, base_price::float8,
			currency, base_weight, dynamic_weight, created_at, updated_at
		FROM model_pricing_rules
		ORDER BY scope_type ASC, resource_type ASC, created_at DESC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := make([]PricingRule, 0)
	for rows.Next() {
		var item PricingRule
		var baseWeight, dynamicWeight []byte
		if err := rows.Scan(
			&item.ID,
			&item.ScopeType,
			&item.ScopeID,
			&item.ResourceType,
			&item.Unit,
			&item.BasePrice,
			&item.Currency,
			&baseWeight,
			&dynamicWeight,
			&item.CreatedAt,
			&item.UpdatedAt,
		); err != nil {
			return nil, err
		}
		item.BaseWeight = decodeObject(baseWeight)
		item.DynamicWeight = decodeObject(dynamicWeight)
		items = append(items, item)
	}
	return items, rows.Err()
}

// ListRateLimitWindows returns rate-limit counters whose reset_at is no more
// than five minutes in the past, newest window first. The result is never nil.
func (s *Store) ListRateLimitWindows(ctx context.Context) ([]RateLimitWindow, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT scope_type, scope_key, metric, window_start, limit_value::float8,
			used_value::float8, reserved_value::float8, reset_at, updated_at
		FROM gateway_rate_limit_counters
		WHERE reset_at >= now() - interval '5 minutes'
		ORDER BY window_start DESC, scope_type ASC, scope_key ASC, metric ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := make([]RateLimitWindow, 0)
	for rows.Next() {
		var item RateLimitWindow
		if err := rows.Scan(
			&item.ScopeType,
			&item.ScopeKey,
			&item.Metric,
			&item.WindowStart,
			&item.LimitValue,
			&item.UsedValue,
			&item.ReservedValue,
			&item.ResetAt,
			&item.UpdatedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	return items, rows.Err()
}

// taskColumns is the projection shared by CreateTask (RETURNING) and GetTask
// (SELECT). Its column order must match scanTask.
const taskColumns = `id::text, kind, user_id, COALESCE(tenant_id, ''), model, request, status, result, billings,
		COALESCE(error, ''), created_at, updated_at`

// scanTask reads one row shaped like taskColumns into a GatewayTask and decodes
// its request, result and billings JSON.
func scanTask(row pgx.Row) (GatewayTask, error) {
	var task GatewayTask
	var requestBytes, resultBytes, billingsBytes []byte
	if err := row.Scan(
		&task.ID,
		&task.Kind,
		&task.UserID,
		&task.TenantID,
		&task.Model,
		&requestBytes,
		&task.Status,
		&resultBytes,
		&billingsBytes,
		&task.Error,
		&task.CreatedAt,
		&task.UpdatedAt,
	); err != nil {
		return GatewayTask{}, err
	}
	task.Request = decodeObject(requestBytes)
	task.Result = decodeObject(resultBytes)
	task.Billings = decodeArray(billingsBytes)
	return task, nil
}

// CreateTask inserts a task in status 'queued' owned by user and returns the
// stored row. An empty user.TenantID is stored as NULL.
//
// It returns an error if user is nil or if input.Request cannot be encoded as
// JSON.
func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *auth.User) (GatewayTask, error) {
	if user == nil {
		return GatewayTask{}, errors.New("store: create task: user is required")
	}
	requestBody, err := marshalObject(input.Request)
	if err != nil {
		return GatewayTask{}, fmt.Errorf("encode task request: %w", err)
	}

	row := s.pool.QueryRow(ctx, `
		INSERT INTO gateway_tasks (kind, user_id, tenant_id, model, request, status)
		VALUES ($1, $2, NULLIF($3, ''), $4, $5, 'queued')
		RETURNING `+taskColumns,
		input.Kind,
		user.ID,
		user.TenantID,
		input.Model,
		requestBody,
	)
	return scanTask(row)
}

// GetTask loads a task by ID. When no row matches, the returned error satisfies
// IsNotFound.
func (s *Store) GetTask(ctx context.Context, taskID string) (GatewayTask, error) {
	row := s.pool.QueryRow(ctx, `
		SELECT `+taskColumns+`
		FROM gateway_tasks
		WHERE id=$1`, taskID)
	return scanTask(row)
}

// IsNotFound reports whether err is, or wraps, pgx.ErrNoRows.
func IsNotFound(err error) bool {
	// errors.Is also matches wrapped errors, which == would miss.
	return errors.Is(err, pgx.ErrNoRows)
}

// marshalObject encodes m as JSON for an object-valued jsonb column. A nil map
// is encoded as "{}" instead of json.Marshal's "null", because a jsonb null
// behaves differently from an object (e.g. `col || '{...}'` yields an array).
func marshalObject(m map[string]any) ([]byte, error) {
	if m == nil {
		return []byte("{}"), nil
	}
	return json.Marshal(m)
}

// decodeObject decodes a JSON object column. It returns nil for empty input or
// for input that does not decode into a map; decode errors are not reported.
func decodeObject(bytes []byte) map[string]any {
	if len(bytes) == 0 {
		return nil
	}
	var out map[string]any
	if err := json.Unmarshal(bytes, &out); err != nil {
		return nil
	}
	return out
}

// decodeArray decodes a JSON array column. It returns nil for empty input or
// for input that does not decode into a slice; decode errors are not reported.
func decodeArray(bytes []byte) []any {
	if len(bytes) == 0 {
		return nil
	}
	var out []any
	if err := json.Unmarshal(bytes, &out); err != nil {
		return nil
	}
	return out
}