1033 lines
34 KiB
Go
1033 lines
34 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
type Store struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// ErrInvalidCredentials is returned by AuthenticateLocalUser for every failed login:
// blank input, unknown account, inactive account, missing password hash or a wrong
// password all map to this single error.
var ErrInvalidCredentials = errors.New("invalid account or password")

// ErrInvalidInvitation is returned by RegisterLocalUser when the invitation code does
// not match an active, unexpired invitation with remaining uses on an active tenant.
var ErrInvalidInvitation = errors.New("invalid or expired invitation code")

// ErrWeakPassword is returned by RegisterLocalUser when the password is shorter than
// 8 bytes (the check uses len, i.e. bytes rather than characters).
var ErrWeakPassword = errors.New("password must be at least 8 characters")
|
|
|
|
func Connect(ctx context.Context, databaseURL string) (*Store, error) {
|
|
pool, err := pgxpool.New(ctx, databaseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := pool.Ping(ctx); err != nil {
|
|
pool.Close()
|
|
return nil, err
|
|
}
|
|
return &Store{pool: pool}, nil
|
|
}
|
|
|
|
func (s *Store) Close() {
|
|
s.pool.Close()
|
|
}
|
|
|
|
func (s *Store) Ping(ctx context.Context) error {
|
|
return s.pool.Ping(ctx)
|
|
}
|
|
|
|
// Platform is an integration platform row as returned by ListPlatforms and
// CreatePlatform. Field order defines the JSON key order of API responses.
type Platform struct {
	ID          string `json:"id"`
	Provider    string `json:"provider"`
	PlatformKey string `json:"platformKey"`
	Name        string `json:"name"`
	BaseURL     string `json:"baseUrl,omitempty"`
	AuthType    string `json:"authType"`
	Status      string `json:"status"`
	Priority    int    `json:"priority"`

	DefaultPricingMode    string  `json:"defaultPricingMode"`
	DefaultDiscountFactor float64 `json:"defaultDiscountFactor"`

	// Config is decoded from the config JSON column (nil when absent or not an object).
	Config map[string]any `json:"config,omitempty"`

	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// CreatePlatformInput is the payload accepted by CreatePlatform. Zero values of
// PlatformKey, DefaultPricingMode, DefaultDiscountFactor and Priority are replaced
// with defaults when the platform is inserted.
type CreatePlatformInput struct {
	Provider    string `json:"provider"`
	PlatformKey string `json:"platformKey"`
	Name        string `json:"name"`
	BaseURL     string `json:"baseUrl"`
	AuthType    string `json:"authType"`

	// Credentials and Config are JSON-encoded into their respective columns.
	Credentials map[string]any `json:"credentials"`
	Config      map[string]any `json:"config"`

	DefaultPricingMode    string  `json:"defaultPricingMode"`
	DefaultDiscountFactor float64 `json:"defaultDiscountFactor"`
	Priority              int     `json:"priority"`
}
|
|
|
|
// PlatformModel is a model offered by an integration platform, as listed by
// ListModels. Provider and PlatformName are taken from the owning platform row.
type PlatformModel struct {
	ID           string `json:"id"`
	PlatformID   string `json:"platformId"`
	BaseModelID  string `json:"baseModelId,omitempty"`
	Provider     string `json:"provider,omitempty"`
	PlatformName string `json:"platformName,omitempty"`
	ModelName    string `json:"modelName"`
	ModelAlias   string `json:"modelAlias,omitempty"`
	ModelType    string `json:"modelType"`
	DisplayName  string `json:"displayName"`

	CapabilityOverride map[string]any `json:"capabilityOverride,omitempty"`
	Capabilities       map[string]any `json:"capabilities,omitempty"`

	PricingMode           string         `json:"pricingMode"`
	DiscountFactor        float64        `json:"discountFactor,omitempty"`
	BillingConfigOverride map[string]any `json:"billingConfigOverride,omitempty"`
	BillingConfig         map[string]any `json:"billingConfig,omitempty"`

	Enabled   bool      `json:"enabled"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// CatalogProvider is a row of model_catalog_providers, as listed by
// ListCatalogProviders.
type CatalogProvider struct {
	ID           string `json:"id"`
	ProviderKey  string `json:"providerKey"`
	DisplayName  string `json:"displayName"`
	ProviderType string `json:"providerType"`

	CapabilitySchema       map[string]any `json:"capabilitySchema,omitempty"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy,omitempty"`

	Status    string    `json:"status"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// BaseModel is a row of base_model_catalog, as listed by ListBaseModels.
type BaseModel struct {
	ID                string `json:"id"`
	ProviderKey       string `json:"providerKey"`
	CanonicalModelKey string `json:"canonicalModelKey"`
	ProviderModelName string `json:"providerModelName"`
	ModelType         string `json:"modelType"`
	DisplayName       string `json:"displayName"`

	Capabilities           map[string]any `json:"capabilities,omitempty"`
	BaseBillingConfig      map[string]any `json:"baseBillingConfig,omitempty"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy,omitempty"`
	PricingVersion         int            `json:"pricingVersion"`

	Status    string    `json:"status"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// PricingRule is a row of model_pricing_rules, as listed by ListPricingRules.
// ScopeID is "" when the rule has no scope_id.
type PricingRule struct {
	ID           string `json:"id"`
	ScopeType    string `json:"scopeType"`
	ScopeID      string `json:"scopeId,omitempty"`
	ResourceType string `json:"resourceType"`
	Unit         string `json:"unit"`

	BasePrice     float64        `json:"basePrice"`
	Currency      string         `json:"currency"`
	BaseWeight    map[string]any `json:"baseWeight,omitempty"`
	DynamicWeight map[string]any `json:"dynamicWeight,omitempty"`

	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// GatewayTenant is a non-deleted row of gateway_tenants, as listed by ListTenants.
// Nullable text columns are reported as "". SyncedAt and SourceUpdatedAt carry the
// database's text rendering of those timestamps rather than parsed times.
type GatewayTenant struct {
	ID                 string `json:"id"`
	TenantKey          string `json:"tenantKey"`
	Source             string `json:"source"`
	ExternalTenantID   string `json:"externalTenantId,omitempty"`
	Name               string `json:"name"`
	Description        string `json:"description,omitempty"`
	DefaultUserGroupID string `json:"defaultUserGroupId,omitempty"`
	PlanKey            string `json:"planKey,omitempty"`

	BillingProfile  map[string]any `json:"billingProfile,omitempty"`
	RateLimitPolicy map[string]any `json:"rateLimitPolicy,omitempty"`
	AuthPolicy      map[string]any `json:"authPolicy,omitempty"`
	Metadata        map[string]any `json:"metadata,omitempty"`

	Status          string    `json:"status"`
	SyncedAt        string    `json:"syncedAt,omitempty"`
	SourceUpdatedAt string    `json:"sourceUpdatedAt,omitempty"`
	CreatedAt       time.Time `json:"createdAt"`
	UpdatedAt       time.Time `json:"updatedAt"`
}
|
|
|
|
// LocalRegisterInput is the self-registration payload consumed by RegisterLocalUser.
// Either Username or Email must be non-blank; InvitationCode is optional.
type LocalRegisterInput struct {
	Username       string `json:"username"`
	Email          string `json:"email"`
	Password       string `json:"password"`
	DisplayName    string `json:"displayName"`
	TenantKey      string `json:"tenantKey"`
	TenantName     string `json:"tenantName"`
	InvitationCode string `json:"invitationCode"`
}
|
|
|
|
// LocalLoginInput is the login payload consumed by AuthenticateLocalUser. Account
// may be a username, an email address or the external user id.
type LocalLoginInput struct {
	Account  string `json:"account"`
	Password string `json:"password"`
}
|
|
|
|
// GatewayUser is a non-deleted row of gateway_users. Nullable text columns are
// reported as "", and LastLoginAt, SyncedAt and SourceUpdatedAt carry the database's
// text rendering of those timestamps. The password hash is never part of this type.
type GatewayUser struct {
	ID             string `json:"id"`
	UserKey        string `json:"userKey"`
	Source         string `json:"source"`
	ExternalUserID string `json:"externalUserId,omitempty"`
	Username       string `json:"username"`
	DisplayName    string `json:"displayName,omitempty"`
	Email          string `json:"email,omitempty"`
	Phone          string `json:"phone,omitempty"`
	AvatarURL      string `json:"avatarUrl,omitempty"`

	GatewayTenantID    string   `json:"gatewayTenantId,omitempty"`
	TenantID           string   `json:"tenantId,omitempty"`
	TenantKey          string   `json:"tenantKey,omitempty"`
	DefaultUserGroupID string   `json:"defaultUserGroupId,omitempty"`
	Roles              []string `json:"roles,omitempty"`

	AuthProfile map[string]any `json:"authProfile,omitempty"`
	Metadata    map[string]any `json:"metadata,omitempty"`

	Status          string    `json:"status"`
	LastLoginAt     string    `json:"lastLoginAt,omitempty"`
	SyncedAt        string    `json:"syncedAt,omitempty"`
	SourceUpdatedAt string    `json:"sourceUpdatedAt,omitempty"`
	CreatedAt       time.Time `json:"createdAt"`
	UpdatedAt       time.Time `json:"updatedAt"`
}
|
|
|
|
// UserGroup is a row of gateway_user_groups, as listed by ListUserGroups. The policy
// maps are decoded from JSON columns.
type UserGroup struct {
	ID          string `json:"id"`
	GroupKey    string `json:"groupKey"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Source      string `json:"source"`
	Priority    int    `json:"priority"`

	RechargeDiscountPolicy map[string]any `json:"rechargeDiscountPolicy,omitempty"`
	BillingDiscountPolicy  map[string]any `json:"billingDiscountPolicy,omitempty"`
	RateLimitPolicy        map[string]any `json:"rateLimitPolicy,omitempty"`
	QuotaPolicy            map[string]any `json:"quotaPolicy,omitempty"`
	Metadata               map[string]any `json:"metadata,omitempty"`

	Status    string    `json:"status"`
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// RateLimitWindow is one row of gateway_rate_limit_counters, as listed by
// ListRateLimitWindows.
type RateLimitWindow struct {
	ScopeType   string    `json:"scopeType"`
	ScopeKey    string    `json:"scopeKey"`
	Metric      string    `json:"metric"`
	WindowStart time.Time `json:"windowStart"`

	LimitValue    float64 `json:"limitValue"`
	UsedValue     float64 `json:"usedValue"`
	ReservedValue float64 `json:"reservedValue"`

	ResetAt   time.Time `json:"resetAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
// CreateTaskInput is the payload consumed by CreateTask. Request is stored verbatim
// as JSON in the task's request column.
type CreateTaskInput struct {
	Kind    string         `json:"kind"`
	Model   string         `json:"model"`
	Request map[string]any `json:"request"`
}
|
|
|
|
// GatewayTask is a row of gateway_tasks as returned by CreateTask and GetTask.
// Nullable identity columns are reported as ""; Request, Result and Billings are
// decoded from JSON columns.
type GatewayTask struct {
	ID   string `json:"id"`
	Kind string `json:"kind"`

	UserID          string `json:"userId"`
	GatewayUserID   string `json:"gatewayUserId,omitempty"`
	UserSource      string `json:"userSource,omitempty"`
	GatewayTenantID string `json:"gatewayTenantId,omitempty"`
	TenantID        string `json:"tenantId,omitempty"`
	TenantKey       string `json:"tenantKey,omitempty"`
	UserGroupID     string `json:"userGroupId,omitempty"`
	UserGroupKey    string `json:"userGroupKey,omitempty"`

	Model    string         `json:"model"`
	Request  map[string]any `json:"request,omitempty"`
	Status   string         `json:"status"`
	Result   map[string]any `json:"result,omitempty"`
	Billings []any          `json:"billings,omitempty"`
	Error    string         `json:"error,omitempty"`

	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
|
|
|
func (s *Store) ListPlatforms(ctx context.Context) ([]Platform, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, provider, platform_key, name, COALESCE(base_url, ''), auth_type, status, priority,
|
|
default_pricing_mode, default_discount_factor::float8, config, created_at, updated_at
|
|
FROM integration_platforms
|
|
ORDER BY priority ASC, created_at DESC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
platforms := make([]Platform, 0)
|
|
for rows.Next() {
|
|
var platform Platform
|
|
var configBytes []byte
|
|
if err := rows.Scan(
|
|
&platform.ID,
|
|
&platform.Provider,
|
|
&platform.PlatformKey,
|
|
&platform.Name,
|
|
&platform.BaseURL,
|
|
&platform.AuthType,
|
|
&platform.Status,
|
|
&platform.Priority,
|
|
&platform.DefaultPricingMode,
|
|
&platform.DefaultDiscountFactor,
|
|
&configBytes,
|
|
&platform.CreatedAt,
|
|
&platform.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
platform.Config = decodeObject(configBytes)
|
|
platforms = append(platforms, platform)
|
|
}
|
|
return platforms, rows.Err()
|
|
}
|
|
|
|
func (s *Store) CreatePlatform(ctx context.Context, input CreatePlatformInput) (Platform, error) {
|
|
credentials, _ := json.Marshal(input.Credentials)
|
|
config, _ := json.Marshal(input.Config)
|
|
if input.DefaultPricingMode == "" {
|
|
input.DefaultPricingMode = "inherit_discount"
|
|
}
|
|
if input.DefaultDiscountFactor == 0 {
|
|
input.DefaultDiscountFactor = 1
|
|
}
|
|
if input.Priority == 0 {
|
|
input.Priority = 100
|
|
}
|
|
var platform Platform
|
|
var configBytes []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
INSERT INTO integration_platforms (provider, platform_key, name, base_url, auth_type, credentials, config, default_pricing_mode, default_discount_factor, priority)
|
|
VALUES ($1, COALESCE(NULLIF($2, ''), 'platform_' || replace(gen_random_uuid()::text, '-', '')), $3, $4, $5, $6, $7, $8, $9, $10)
|
|
RETURNING id::text, provider, platform_key, name, COALESCE(base_url, ''), auth_type, status, priority,
|
|
default_pricing_mode, default_discount_factor::float8, config, created_at, updated_at`,
|
|
input.Provider, input.PlatformKey, input.Name, input.BaseURL, input.AuthType, credentials, config, input.DefaultPricingMode, input.DefaultDiscountFactor, input.Priority,
|
|
).Scan(
|
|
&platform.ID,
|
|
&platform.Provider,
|
|
&platform.PlatformKey,
|
|
&platform.Name,
|
|
&platform.BaseURL,
|
|
&platform.AuthType,
|
|
&platform.Status,
|
|
&platform.Priority,
|
|
&platform.DefaultPricingMode,
|
|
&platform.DefaultDiscountFactor,
|
|
&configBytes,
|
|
&platform.CreatedAt,
|
|
&platform.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return Platform{}, err
|
|
}
|
|
platform.Config = decodeObject(configBytes)
|
|
return platform, nil
|
|
}
|
|
|
|
func (s *Store) ListModels(ctx context.Context) ([]PlatformModel, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT m.id::text, m.platform_id::text, COALESCE(m.base_model_id::text, ''), p.provider, p.name,
|
|
m.model_name, COALESCE(m.model_alias, ''), m.model_type, m.display_name,
|
|
m.capability_override, m.capabilities, m.pricing_mode, COALESCE(m.discount_factor, 0)::float8,
|
|
m.billing_config_override, m.billing_config, m.enabled, m.created_at, m.updated_at
|
|
FROM platform_models m
|
|
JOIN integration_platforms p ON p.id = m.platform_id
|
|
ORDER BY m.model_type ASC, m.model_name ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
models := make([]PlatformModel, 0)
|
|
for rows.Next() {
|
|
var model PlatformModel
|
|
var capabilityOverride []byte
|
|
var capabilities []byte
|
|
var billingConfigOverride []byte
|
|
var billingConfig []byte
|
|
if err := rows.Scan(
|
|
&model.ID,
|
|
&model.PlatformID,
|
|
&model.BaseModelID,
|
|
&model.Provider,
|
|
&model.PlatformName,
|
|
&model.ModelName,
|
|
&model.ModelAlias,
|
|
&model.ModelType,
|
|
&model.DisplayName,
|
|
&capabilityOverride,
|
|
&capabilities,
|
|
&model.PricingMode,
|
|
&model.DiscountFactor,
|
|
&billingConfigOverride,
|
|
&billingConfig,
|
|
&model.Enabled,
|
|
&model.CreatedAt,
|
|
&model.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
model.CapabilityOverride = decodeObject(capabilityOverride)
|
|
model.Capabilities = decodeObject(capabilities)
|
|
model.BillingConfigOverride = decodeObject(billingConfigOverride)
|
|
model.BillingConfig = decodeObject(billingConfig)
|
|
models = append(models, model)
|
|
}
|
|
return models, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListCatalogProviders(ctx context.Context) ([]CatalogProvider, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, provider_key, display_name, provider_type, capability_schema,
|
|
default_rate_limit_policy, status, created_at, updated_at
|
|
FROM model_catalog_providers
|
|
ORDER BY provider_key ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]CatalogProvider, 0)
|
|
for rows.Next() {
|
|
var item CatalogProvider
|
|
var capabilitySchema []byte
|
|
var rateLimitPolicy []byte
|
|
if err := rows.Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.DisplayName,
|
|
&item.ProviderType,
|
|
&capabilitySchema,
|
|
&rateLimitPolicy,
|
|
&item.Status,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.CapabilitySchema = decodeObject(capabilitySchema)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
|
|
capabilities, base_billing_config, default_rate_limit_policy, pricing_version,
|
|
status, created_at, updated_at
|
|
FROM base_model_catalog
|
|
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]BaseModel, 0)
|
|
for rows.Next() {
|
|
var item BaseModel
|
|
var capabilities []byte
|
|
var billingConfig []byte
|
|
var rateLimitPolicy []byte
|
|
if err := rows.Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.CanonicalModelKey,
|
|
&item.ProviderModelName,
|
|
&item.ModelType,
|
|
&item.DisplayName,
|
|
&capabilities,
|
|
&billingConfig,
|
|
&rateLimitPolicy,
|
|
&item.PricingVersion,
|
|
&item.Status,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.Capabilities = decodeObject(capabilities)
|
|
item.BaseBillingConfig = decodeObject(billingConfig)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListPricingRules(ctx context.Context) ([]PricingRule, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, scope_type, COALESCE(scope_id::text, ''), resource_type, unit,
|
|
base_price::float8, currency, base_weight, dynamic_weight, created_at, updated_at
|
|
FROM model_pricing_rules
|
|
ORDER BY scope_type ASC, resource_type ASC, created_at DESC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]PricingRule, 0)
|
|
for rows.Next() {
|
|
var item PricingRule
|
|
var baseWeight []byte
|
|
var dynamicWeight []byte
|
|
if err := rows.Scan(
|
|
&item.ID,
|
|
&item.ScopeType,
|
|
&item.ScopeID,
|
|
&item.ResourceType,
|
|
&item.Unit,
|
|
&item.BasePrice,
|
|
&item.Currency,
|
|
&baseWeight,
|
|
&dynamicWeight,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.BaseWeight = decodeObject(baseWeight)
|
|
item.DynamicWeight = decodeObject(dynamicWeight)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListTenants(ctx context.Context) ([]GatewayTenant, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, tenant_key, source, COALESCE(external_tenant_id, ''), name, COALESCE(description, ''),
|
|
COALESCE(default_user_group_id::text, ''), COALESCE(plan_key, ''), billing_profile, rate_limit_policy,
|
|
auth_policy, metadata, status, COALESCE(synced_at::text, ''), COALESCE(source_updated_at::text, ''),
|
|
created_at, updated_at
|
|
FROM gateway_tenants
|
|
WHERE deleted_at IS NULL
|
|
ORDER BY created_at DESC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]GatewayTenant, 0)
|
|
for rows.Next() {
|
|
var item GatewayTenant
|
|
var billingProfile []byte
|
|
var rateLimitPolicy []byte
|
|
var authPolicy []byte
|
|
var metadata []byte
|
|
if err := rows.Scan(
|
|
&item.ID,
|
|
&item.TenantKey,
|
|
&item.Source,
|
|
&item.ExternalTenantID,
|
|
&item.Name,
|
|
&item.Description,
|
|
&item.DefaultUserGroupID,
|
|
&item.PlanKey,
|
|
&billingProfile,
|
|
&rateLimitPolicy,
|
|
&authPolicy,
|
|
&metadata,
|
|
&item.Status,
|
|
&item.SyncedAt,
|
|
&item.SourceUpdatedAt,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.BillingProfile = decodeObject(billingProfile)
|
|
item.RateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
item.AuthPolicy = decodeObject(authPolicy)
|
|
item.Metadata = decodeObject(metadata)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListUsers(ctx context.Context) ([]GatewayUser, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, user_key, source, COALESCE(external_user_id, ''), username,
|
|
COALESCE(display_name, ''), COALESCE(email, ''), COALESCE(phone, ''), COALESCE(avatar_url, ''),
|
|
COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_key, ''),
|
|
COALESCE(default_user_group_id::text, ''), roles, auth_profile, metadata,
|
|
status, COALESCE(last_login_at::text, ''), COALESCE(synced_at::text, ''), COALESCE(source_updated_at::text, ''),
|
|
created_at, updated_at
|
|
FROM gateway_users
|
|
WHERE deleted_at IS NULL
|
|
ORDER BY created_at DESC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]GatewayUser, 0)
|
|
for rows.Next() {
|
|
var item GatewayUser
|
|
var roles []byte
|
|
var authProfile []byte
|
|
var metadata []byte
|
|
if err := rows.Scan(
|
|
&item.ID,
|
|
&item.UserKey,
|
|
&item.Source,
|
|
&item.ExternalUserID,
|
|
&item.Username,
|
|
&item.DisplayName,
|
|
&item.Email,
|
|
&item.Phone,
|
|
&item.AvatarURL,
|
|
&item.GatewayTenantID,
|
|
&item.TenantID,
|
|
&item.TenantKey,
|
|
&item.DefaultUserGroupID,
|
|
&roles,
|
|
&authProfile,
|
|
&metadata,
|
|
&item.Status,
|
|
&item.LastLoginAt,
|
|
&item.SyncedAt,
|
|
&item.SourceUpdatedAt,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.Roles = decodeStringArray(roles)
|
|
item.AuthProfile = decodeObject(authProfile)
|
|
item.Metadata = decodeObject(metadata)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) ListUserGroups(ctx context.Context) ([]UserGroup, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id::text, group_key, name, COALESCE(description, ''), source, priority,
|
|
recharge_discount_policy, billing_discount_policy, rate_limit_policy, quota_policy, metadata,
|
|
status, created_at, updated_at
|
|
FROM gateway_user_groups
|
|
ORDER BY priority ASC, group_key ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]UserGroup, 0)
|
|
for rows.Next() {
|
|
var item UserGroup
|
|
var rechargeDiscountPolicy []byte
|
|
var billingDiscountPolicy []byte
|
|
var rateLimitPolicy []byte
|
|
var quotaPolicy []byte
|
|
var metadata []byte
|
|
if err := rows.Scan(
|
|
&item.ID,
|
|
&item.GroupKey,
|
|
&item.Name,
|
|
&item.Description,
|
|
&item.Source,
|
|
&item.Priority,
|
|
&rechargeDiscountPolicy,
|
|
&billingDiscountPolicy,
|
|
&rateLimitPolicy,
|
|
"aPolicy,
|
|
&metadata,
|
|
&item.Status,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.RechargeDiscountPolicy = decodeObject(rechargeDiscountPolicy)
|
|
item.BillingDiscountPolicy = decodeObject(billingDiscountPolicy)
|
|
item.RateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
item.QuotaPolicy = decodeObject(quotaPolicy)
|
|
item.Metadata = decodeObject(metadata)
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) RegisterLocalUser(ctx context.Context, input LocalRegisterInput) (GatewayUser, error) {
|
|
account := normalizeAccount(firstNonEmpty(input.Username, input.Email))
|
|
if account == "" {
|
|
return GatewayUser{}, errors.New("username or email is required")
|
|
}
|
|
if len(input.Password) < 8 {
|
|
return GatewayUser{}, ErrWeakPassword
|
|
}
|
|
tenantKey := normalizeKey(input.TenantKey)
|
|
if tenantKey == "" {
|
|
tenantKey = "personal-" + normalizeKey(account)
|
|
}
|
|
tenantName := strings.TrimSpace(input.TenantName)
|
|
if tenantName == "" {
|
|
tenantName = tenantKey
|
|
}
|
|
displayName := strings.TrimSpace(input.DisplayName)
|
|
username := strings.TrimSpace(input.Username)
|
|
if username == "" {
|
|
username = account
|
|
}
|
|
email := strings.TrimSpace(strings.ToLower(input.Email))
|
|
invitationCode := strings.TrimSpace(input.InvitationCode)
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
var tenantID string
|
|
userGroupID := ""
|
|
role := "user"
|
|
invitationID := ""
|
|
if invitationCode != "" {
|
|
if err := tx.QueryRow(ctx, `
|
|
SELECT i.id::text,
|
|
i.tenant_id::text,
|
|
t.tenant_key,
|
|
t.name,
|
|
COALESCE(i.user_group_id::text, t.default_user_group_id::text, ''),
|
|
COALESCE(NULLIF(i.role, ''), 'user')
|
|
FROM gateway_tenant_invitations i
|
|
JOIN gateway_tenants t ON t.id = i.tenant_id
|
|
WHERE lower(i.invite_code) = lower($1)
|
|
AND i.status = 'active'
|
|
AND t.status = 'active'
|
|
AND (i.expires_at IS NULL OR i.expires_at > now())
|
|
AND (i.max_uses IS NULL OR i.used_count < i.max_uses)
|
|
FOR UPDATE OF i`,
|
|
invitationCode,
|
|
).Scan(&invitationID, &tenantID, &tenantKey, &tenantName, &userGroupID, &role); err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return GatewayUser{}, ErrInvalidInvitation
|
|
}
|
|
return GatewayUser{}, err
|
|
}
|
|
} else if err := tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_tenants (tenant_key, source, external_tenant_id, name)
|
|
VALUES ($1, 'gateway', $1, $2)
|
|
ON CONFLICT (tenant_key) DO UPDATE SET updated_at=now()
|
|
RETURNING id::text`,
|
|
tenantKey, tenantName,
|
|
).Scan(&tenantID); err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
|
|
rolesJSON, err := json.Marshal([]string{role})
|
|
if err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
|
|
var user GatewayUser
|
|
var roles []byte
|
|
var authProfile []byte
|
|
var metadata []byte
|
|
if err := tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_users (
|
|
user_key, source, external_user_id, username, display_name, email,
|
|
password_hash, gateway_tenant_id, tenant_id, tenant_key, default_user_group_id, roles, status
|
|
)
|
|
VALUES ($1, 'gateway', $2, $3, NULLIF($4, ''), NULLIF($5, ''), $6, $7::uuid, $8, $8, NULLIF($9, '')::uuid, $10::jsonb, 'active')
|
|
RETURNING id::text, user_key, source, COALESCE(external_user_id, ''), username,
|
|
COALESCE(display_name, ''), COALESCE(email, ''), COALESCE(phone, ''), COALESCE(avatar_url, ''),
|
|
COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_key, ''),
|
|
COALESCE(default_user_group_id::text, ''), roles, auth_profile, metadata,
|
|
status, COALESCE(last_login_at::text, ''), COALESCE(synced_at::text, ''), COALESCE(source_updated_at::text, ''),
|
|
created_at, updated_at`,
|
|
"gateway:"+account, account, username, displayName, email, string(passwordHash), tenantID, tenantKey, userGroupID, string(rolesJSON),
|
|
).Scan(
|
|
&user.ID,
|
|
&user.UserKey,
|
|
&user.Source,
|
|
&user.ExternalUserID,
|
|
&user.Username,
|
|
&user.DisplayName,
|
|
&user.Email,
|
|
&user.Phone,
|
|
&user.AvatarURL,
|
|
&user.GatewayTenantID,
|
|
&user.TenantID,
|
|
&user.TenantKey,
|
|
&user.DefaultUserGroupID,
|
|
&roles,
|
|
&authProfile,
|
|
&metadata,
|
|
&user.Status,
|
|
&user.LastLoginAt,
|
|
&user.SyncedAt,
|
|
&user.SourceUpdatedAt,
|
|
&user.CreatedAt,
|
|
&user.UpdatedAt,
|
|
); err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
if invitationID != "" {
|
|
if _, err := tx.Exec(ctx, `
|
|
UPDATE gateway_tenant_invitations
|
|
SET used_count = used_count + 1, updated_at = now()
|
|
WHERE id = $1::uuid`, invitationID); err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
}
|
|
if userGroupID != "" {
|
|
metadata, err := json.Marshal(map[string]any{
|
|
"source": "registration",
|
|
"invitationId": invitationID,
|
|
})
|
|
if err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
if _, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_user_group_memberships (group_id, principal_type, principal_id, source, metadata)
|
|
VALUES ($1::uuid, 'user', $2, 'gateway', $3::jsonb)
|
|
ON CONFLICT (group_id, principal_type, principal_id)
|
|
DO UPDATE SET status = 'active', updated_at = now()`,
|
|
userGroupID, user.ID, string(metadata),
|
|
); err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
}
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return GatewayUser{}, err
|
|
}
|
|
user.Roles = decodeStringArray(roles)
|
|
user.AuthProfile = decodeObject(authProfile)
|
|
user.Metadata = decodeObject(metadata)
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Store) AuthenticateLocalUser(ctx context.Context, input LocalLoginInput) (GatewayUser, error) {
|
|
account := normalizeAccount(input.Account)
|
|
if account == "" || input.Password == "" {
|
|
return GatewayUser{}, ErrInvalidCredentials
|
|
}
|
|
var user GatewayUser
|
|
var passwordHash string
|
|
var roles []byte
|
|
var authProfile []byte
|
|
var metadata []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
SELECT id::text, user_key, source, COALESCE(external_user_id, ''), username,
|
|
COALESCE(display_name, ''), COALESCE(email, ''), COALESCE(phone, ''), COALESCE(avatar_url, ''),
|
|
COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_key, ''),
|
|
COALESCE(default_user_group_id::text, ''), roles, auth_profile, metadata,
|
|
status, COALESCE(password_hash, ''), COALESCE(last_login_at::text, ''), COALESCE(synced_at::text, ''),
|
|
COALESCE(source_updated_at::text, ''), created_at, updated_at
|
|
FROM gateway_users
|
|
WHERE source='gateway'
|
|
AND deleted_at IS NULL
|
|
AND (external_user_id=$1 OR lower(username)=$1 OR lower(COALESCE(email, ''))=$1)
|
|
ORDER BY created_at ASC
|
|
LIMIT 1`, account,
|
|
).Scan(
|
|
&user.ID,
|
|
&user.UserKey,
|
|
&user.Source,
|
|
&user.ExternalUserID,
|
|
&user.Username,
|
|
&user.DisplayName,
|
|
&user.Email,
|
|
&user.Phone,
|
|
&user.AvatarURL,
|
|
&user.GatewayTenantID,
|
|
&user.TenantID,
|
|
&user.TenantKey,
|
|
&user.DefaultUserGroupID,
|
|
&roles,
|
|
&authProfile,
|
|
&metadata,
|
|
&user.Status,
|
|
&passwordHash,
|
|
&user.LastLoginAt,
|
|
&user.SyncedAt,
|
|
&user.SourceUpdatedAt,
|
|
&user.CreatedAt,
|
|
&user.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
if IsNotFound(err) {
|
|
return GatewayUser{}, ErrInvalidCredentials
|
|
}
|
|
return GatewayUser{}, err
|
|
}
|
|
if user.Status != "active" || passwordHash == "" {
|
|
return GatewayUser{}, ErrInvalidCredentials
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(input.Password)); err != nil {
|
|
return GatewayUser{}, ErrInvalidCredentials
|
|
}
|
|
user.Roles = decodeStringArray(roles)
|
|
user.AuthProfile = decodeObject(authProfile)
|
|
user.Metadata = decodeObject(metadata)
|
|
_, _ = s.pool.Exec(ctx, `UPDATE gateway_users SET last_login_at=now(), updated_at=now() WHERE id=$1`, user.ID)
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Store) ListRateLimitWindows(ctx context.Context) ([]RateLimitWindow, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT scope_type, scope_key, metric, window_start, limit_value::float8, used_value::float8,
|
|
reserved_value::float8, reset_at, updated_at
|
|
FROM gateway_rate_limit_counters
|
|
WHERE reset_at >= now() - interval '5 minutes'
|
|
ORDER BY window_start DESC, scope_type ASC, scope_key ASC, metric ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]RateLimitWindow, 0)
|
|
for rows.Next() {
|
|
var item RateLimitWindow
|
|
if err := rows.Scan(
|
|
&item.ScopeType,
|
|
&item.ScopeKey,
|
|
&item.Metric,
|
|
&item.WindowStart,
|
|
&item.LimitValue,
|
|
&item.UsedValue,
|
|
&item.ReservedValue,
|
|
&item.ResetAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) CreateTask(ctx context.Context, input CreateTaskInput, user *auth.User) (GatewayTask, error) {
|
|
requestBody, _ := json.Marshal(input.Request)
|
|
var task GatewayTask
|
|
var requestBytes []byte
|
|
var resultBytes []byte
|
|
var billingsBytes []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
INSERT INTO gateway_tasks (
|
|
kind, user_id, gateway_user_id, user_source, gateway_tenant_id, tenant_id, tenant_key,
|
|
api_key_id, user_group_id, user_group_key, model, request, status
|
|
)
|
|
VALUES ($1, $2, NULLIF($3, '')::uuid, COALESCE(NULLIF($4, ''), 'gateway'), NULLIF($5, '')::uuid, NULLIF($6, ''), NULLIF($7, ''), NULLIF($8, ''), NULLIF($9, '')::uuid, NULLIF($10, ''), $11, $12, 'queued')
|
|
RETURNING id::text, kind, user_id, COALESCE(gateway_user_id::text, ''), user_source,
|
|
COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_key, ''),
|
|
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model, request, status, result, billings, COALESCE(error, ''), created_at, updated_at`,
|
|
input.Kind, user.ID, user.GatewayUserID, user.Source, user.GatewayTenantID, user.TenantID, user.TenantKey, user.APIKeyID, user.UserGroupID, user.UserGroupKey, input.Model, requestBody,
|
|
).Scan(&task.ID, &task.Kind, &task.UserID, &task.GatewayUserID, &task.UserSource, &task.GatewayTenantID, &task.TenantID, &task.TenantKey, &task.UserGroupID, &task.UserGroupKey, &task.Model, &requestBytes, &task.Status, &resultBytes, &billingsBytes, &task.Error, &task.CreatedAt, &task.UpdatedAt)
|
|
if err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
task.Request = decodeObject(requestBytes)
|
|
task.Result = decodeObject(resultBytes)
|
|
task.Billings = decodeArray(billingsBytes)
|
|
return task, nil
|
|
}
|
|
|
|
func (s *Store) GetTask(ctx context.Context, taskID string) (GatewayTask, error) {
|
|
var task GatewayTask
|
|
var requestBytes []byte
|
|
var resultBytes []byte
|
|
var billingsBytes []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
SELECT id::text, kind, user_id, COALESCE(gateway_user_id::text, ''), user_source,
|
|
COALESCE(gateway_tenant_id::text, ''), COALESCE(tenant_id, ''), COALESCE(tenant_key, ''),
|
|
COALESCE(user_group_id::text, ''), COALESCE(user_group_key, ''), model, request, status, result, billings, COALESCE(error, ''), created_at, updated_at
|
|
FROM gateway_tasks
|
|
WHERE id=$1`, taskID,
|
|
).Scan(&task.ID, &task.Kind, &task.UserID, &task.GatewayUserID, &task.UserSource, &task.GatewayTenantID, &task.TenantID, &task.TenantKey, &task.UserGroupID, &task.UserGroupKey, &task.Model, &requestBytes, &task.Status, &resultBytes, &billingsBytes, &task.Error, &task.CreatedAt, &task.UpdatedAt)
|
|
if err != nil {
|
|
return GatewayTask{}, err
|
|
}
|
|
task.Request = decodeObject(requestBytes)
|
|
task.Result = decodeObject(resultBytes)
|
|
task.Billings = decodeArray(billingsBytes)
|
|
return task, nil
|
|
}
|
|
|
|
func IsNotFound(err error) bool {
|
|
return err == pgx.ErrNoRows
|
|
}
|
|
|
|
// decodeObject parses raw as a JSON object. It returns nil for empty input,
// malformed JSON, JSON null, or any JSON value that is not an object.
func decodeObject(raw []byte) map[string]any {
	var obj map[string]any
	if len(raw) == 0 || json.Unmarshal(raw, &obj) != nil {
		return nil
	}
	return obj
}
|
|
|
|
// decodeArray parses raw as a JSON array. It returns nil for empty input, malformed
// JSON, JSON null, or any JSON value that is not an array.
func decodeArray(raw []byte) []any {
	if len(raw) == 0 {
		return nil
	}
	var items []any
	if json.Unmarshal(raw, &items) != nil {
		return nil
	}
	return items
}
|
|
|
|
// decodeStringArray parses raw as a JSON array of strings. It returns nil for empty
// input, malformed JSON, JSON null, or when any element is not a string.
func decodeStringArray(raw []byte) []string {
	if len(raw) == 0 {
		return nil
	}
	var values []string
	if err := json.Unmarshal(raw, &values); err != nil {
		return nil
	}
	return values
}
|
|
|
|
// firstNonEmpty returns the first argument that contains a non-whitespace character,
// unmodified (not trimmed), or "" if every argument is blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if strings.TrimSpace(values[i]) == "" {
			continue
		}
		return values[i]
	}
	return ""
}
|
|
|
|
// normalizeAccount canonicalizes a login identifier by trimming surrounding
// whitespace and lower-casing it.
func normalizeAccount(value string) string {
	trimmed := strings.TrimSpace(value)
	return strings.ToLower(trimmed)
}
|
|
|
|
// normalizeKey turns free text into a lower-case slug.
//
// Rules:
//   - Letters and digits are kept.
//   - Any run of '-', '_', '.' or whitespace between them becomes a single '-'.
//   - Every other character is dropped.
//   - Separators are never emitted at the start or end.
//
// It never returns "": input with no letters or digits yields "default".
func normalizeKey(value string) string {
	var sb strings.Builder
	sepPending := false
	for _, r := range strings.ToLower(strings.TrimSpace(value)) {
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			// Emit a deferred separator only between two kept characters.
			if sepPending && sb.Len() > 0 {
				sb.WriteByte('-')
			}
			sepPending = false
			sb.WriteRune(r)
			continue
		}
		if r == '-' || r == '_' || r == '.' || unicode.IsSpace(r) {
			sepPending = true
		}
	}
	if sb.Len() == 0 {
		return "default"
	}
	return sb.String()
}
|