// Source: easyai-ai-gateway/apps/api/internal/store/platform_models.go (344 lines, 11 KiB, Go)

package store
import (
"context"
"encoding/json"
"strings"
"github.com/jackc/pgx/v5"
)
// platformModelQuerier is the single-row query capability needed by
// createPlatformModel and lookupBaseModel. Within this file both s.pool and
// the transaction returned by s.pool.Begin are passed where it is expected,
// so the same code can run inside or outside a transaction.
type platformModelQuerier interface {
// QueryRow runs sql with args and returns a row handle to Scan from.
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
// modelCatalogSnapshot holds the columns of one base_model_catalog row as
// read by lookupBaseModel. createPlatformModel uses it as the source of
// defaults for fields the caller left empty. The zero value means "no base
// model found" and contributes no defaults.
type modelCatalogSnapshot struct {
// ID is the row id rendered as text (id::text).
ID string
// ProviderKey is the provider_key column; normalizePlatformModelAlias
// strips a leading "<ProviderKey>:" from aliases.
ProviderKey string
// CanonicalModelKey is the canonical_model_key column.
CanonicalModelKey string
// ProviderModelName is the provider_model_name column.
ProviderModelName string
// ModelType is the model_type column.
ModelType string
// DisplayName is the display_name column.
DisplayName string
// Capabilities is the decoded capabilities column.
Capabilities map[string]any
// BaseBillingConfig is the decoded base_billing_config column.
BaseBillingConfig map[string]any
// DefaultRateLimitPolicy is the decoded default_rate_limit_policy column.
DefaultRateLimitPolicy map[string]any
// RuntimePolicySetID is runtime_policy_set_id as text; it is "" when the
// column is NULL (the query applies COALESCE).
RuntimePolicySetID string
// RuntimePolicyOverride is the decoded runtime_policy_override column.
RuntimePolicyOverride map[string]any
}
// CreatePlatformModel upserts a single platform model using s.pool directly
// (no surrounding transaction) and returns the stored row. All defaulting and
// persistence logic lives in createPlatformModel.
func (s *Store) CreatePlatformModel(ctx context.Context, input CreatePlatformModelInput) (PlatformModel, error) {
	var q platformModelQuerier = s.pool
	return s.createPlatformModel(ctx, q, input)
}
// ReplacePlatformModels makes the set of models stored for platformID match
// inputs. Inside one transaction it upserts every input (forcing its
// PlatformID to platformID), then deletes the platform's remaining models
// together with their gateway_access_rules rows of resource_type
// 'platform_model'. With empty inputs every model of the platform is removed.
// After a successful commit it returns the result of ListPlatformModels.
func (s *Store) ReplacePlatformModels(ctx context.Context, platformID string, inputs []CreatePlatformModelInput) ([]PlatformModel, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return nil, err
	}
	// Covers every early return below; see the pgx Tx docs for the
	// behaviour of Rollback after a successful Commit.
	defer tx.Rollback(ctx)

	keptIDs := make([]string, 0, len(inputs))
	for i := range inputs {
		entry := inputs[i] // work on a copy so the caller's slice is not modified
		entry.PlatformID = platformID
		upserted, upsertErr := s.createPlatformModel(ctx, tx, entry)
		if upsertErr != nil {
			return nil, upsertErr
		}
		keptIDs = append(keptIDs, upserted.ID)
	}

	// Default: nothing was upserted, so every model of the platform goes and
	// no keep-list parameter is bound.
	pruneSQL := `
WITH deleted AS (
DELETE FROM platform_models
WHERE platform_id = $1::uuid
RETURNING id
)
DELETE FROM gateway_access_rules
WHERE resource_type = 'platform_model'
AND resource_id IN (SELECT id FROM deleted)`
	pruneArgs := []any{platformID}
	if len(keptIDs) > 0 {
		// Otherwise spare the rows that were just upserted.
		pruneSQL = `
WITH deleted AS (
DELETE FROM platform_models
WHERE platform_id = $1::uuid
AND NOT (id::text = ANY($2::text[]))
RETURNING id
)
DELETE FROM gateway_access_rules
WHERE resource_type = 'platform_model'
AND resource_id IN (SELECT id FROM deleted)`
		pruneArgs = append(pruneArgs, keptIDs)
	}
	if _, execErr := tx.Exec(ctx, pruneSQL, pruneArgs...); execErr != nil {
		return nil, execErr
	}

	if commitErr := tx.Commit(ctx); commitErr != nil {
		return nil, commitErr
	}
	return s.ListPlatformModels(ctx, platformID)
}
// createPlatformModel upserts one platform_models row through q and returns
// the stored row.
//
// Fields left empty on input are filled from the matching base_model_catalog
// entry (see lookupBaseModel):
//   - ModelType, ModelName and DisplayName fall back to the base values
//     (DisplayName finally falls back to the model name);
//   - ModelAlias is passed through normalizePlatformModelAlias;
//   - PricingMode defaults to "inherit_discount";
//   - Capabilities / BillingConfig default to the base object merged with the
//     corresponding override;
//   - RateLimitPolicy, RuntimePolicySetID and RuntimePolicyOverride default to
//     the base values.
//
// A DiscountFactor <= 0 is stored as NULL, as is base_model_id when no base
// model was found. The row is keyed by (platform_id, model_name, model_type);
// on conflict the existing row is overwritten and re-enabled.
//
// It returns any error from the base-model lookup that IsNotFound does not
// recognise, any JSON encoding error of the object-valued fields, and any
// error from the INSERT itself.
func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, input CreatePlatformModelInput) (PlatformModel, error) {
	// Errors recognised by IsNotFound are tolerated: lookupBaseModel returns a
	// zero snapshot on error, so base simply supplies no defaults then.
	base, err := s.lookupBaseModel(ctx, q, input.BaseModelID, input.CanonicalModelKey, input.ModelName)
	if err != nil && !IsNotFound(err) {
		return PlatformModel{}, err
	}
	if input.ModelType == "" {
		input.ModelType = base.ModelType
	}
	if input.ModelName == "" {
		input.ModelName = base.ProviderModelName
	}
	if input.DisplayName == "" {
		input.DisplayName = firstNonEmpty(base.DisplayName, input.ModelName)
	}
	input.ModelAlias = normalizePlatformModelAlias(input.ModelAlias, base)
	if input.PricingMode == "" {
		input.PricingMode = "inherit_discount"
	}
	capabilities := input.Capabilities
	if len(capabilities) == 0 {
		capabilities = mergeObjects(base.Capabilities, input.CapabilityOverride)
	}
	billingConfig := input.BillingConfig
	if len(billingConfig) == 0 {
		billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride)
	}
	rateLimitPolicy := input.RateLimitPolicy
	if len(rateLimitPolicy) == 0 {
		rateLimitPolicy = base.DefaultRateLimitPolicy
	}
	runtimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
	if runtimePolicySetID == "" {
		runtimePolicySetID = base.RuntimePolicySetID
	}
	runtimePolicyOverride := input.RuntimePolicyOverride
	if len(runtimePolicyOverride) == 0 {
		runtimePolicyOverride = base.RuntimePolicyOverride
	}

	// Encode every object-valued column up front. A map[string]any can hold
	// values encoding/json refuses (NaN, channels, funcs, ...); report that
	// error directly instead of sending an empty string to a ::jsonb
	// parameter and getting an opaque database error back.
	var (
		capabilityOverrideJSON    string
		capabilitiesJSON          string
		billingOverrideJSON       string
		billingJSON               string
		permissionJSON            string
		retryJSON                 string
		rateLimitJSON             string
		runtimePolicyOverrideJSON string
	)
	for _, field := range []struct {
		dst *string
		src map[string]any
	}{
		{&capabilityOverrideJSON, input.CapabilityOverride},
		{&capabilitiesJSON, capabilities},
		{&billingOverrideJSON, input.BillingConfigOverride},
		{&billingJSON, billingConfig},
		{&permissionJSON, input.PermissionConfig},
		{&retryJSON, input.RetryPolicy},
		{&rateLimitJSON, rateLimitPolicy},
		{&runtimePolicyOverrideJSON, runtimePolicyOverride},
	} {
		// nil maps are stored as "{}" rather than JSON null.
		raw, marshalErr := json.Marshal(emptyObjectIfNil(field.src))
		if marshalErr != nil {
			return PlatformModel{}, marshalErr
		}
		*field.dst = string(raw)
	}

	// Untyped nil binds SQL NULL for the optional columns.
	discount := any(nil)
	if input.DiscountFactor > 0 {
		discount = input.DiscountFactor
	}
	baseID := any(nil)
	if base.ID != "" {
		baseID = base.ID
	}

	var model PlatformModel
	var capabilityOverrideBytes []byte
	var capabilitiesBytes []byte
	var billingOverrideBytes []byte
	var billingBytes []byte
	var permissionBytes []byte
	var retryPolicyBytes []byte
	var rateLimitPolicyBytes []byte
	var runtimePolicyOverrideBytes []byte
	err = q.QueryRow(ctx, `
INSERT INTO platform_models (
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
capability_override, capabilities, pricing_mode, discount_factor,
pricing_rule_set_id, billing_config_override, billing_config, permission_config, retry_policy, rate_limit_policy,
runtime_policy_set_id, runtime_policy_override, enabled
)
VALUES (
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6,
$7::jsonb, $8::jsonb, $9, $10::numeric,
NULLIF($11, '')::uuid, $12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb, $16::jsonb,
NULLIF($17, '')::uuid, $18::jsonb, true
)
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
SET base_model_id = EXCLUDED.base_model_id,
model_alias = EXCLUDED.model_alias,
display_name = EXCLUDED.display_name,
capability_override = EXCLUDED.capability_override,
capabilities = EXCLUDED.capabilities,
pricing_mode = EXCLUDED.pricing_mode,
discount_factor = EXCLUDED.discount_factor,
pricing_rule_set_id = EXCLUDED.pricing_rule_set_id,
billing_config_override = EXCLUDED.billing_config_override,
billing_config = EXCLUDED.billing_config,
permission_config = EXCLUDED.permission_config,
retry_policy = EXCLUDED.retry_policy,
rate_limit_policy = EXCLUDED.rate_limit_policy,
runtime_policy_set_id = EXCLUDED.runtime_policy_set_id,
runtime_policy_override = EXCLUDED.runtime_policy_override,
enabled = true,
updated_at = now()
RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_name,
COALESCE(model_alias, ''), model_type, display_name, capability_override,
capabilities, pricing_mode, COALESCE(discount_factor, 0)::float8,
COALESCE(pricing_rule_set_id::text, ''), billing_config_override, billing_config,
permission_config, retry_policy, rate_limit_policy, COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override,
enabled, created_at, updated_at`,
		input.PlatformID,
		baseID,
		input.ModelName,
		input.ModelAlias,
		input.ModelType,
		input.DisplayName,
		capabilityOverrideJSON,
		capabilitiesJSON,
		input.PricingMode,
		discount,
		input.PricingRuleSetID,
		billingOverrideJSON,
		billingJSON,
		permissionJSON,
		retryJSON,
		rateLimitJSON,
		runtimePolicySetID,
		runtimePolicyOverrideJSON,
	).Scan(
		&model.ID,
		&model.PlatformID,
		&model.BaseModelID,
		&model.ModelName,
		&model.ModelAlias,
		&model.ModelType,
		&model.DisplayName,
		&capabilityOverrideBytes,
		&capabilitiesBytes,
		&model.PricingMode,
		&model.DiscountFactor,
		&model.PricingRuleSetID,
		&billingOverrideBytes,
		&billingBytes,
		&permissionBytes,
		&retryPolicyBytes,
		&rateLimitPolicyBytes,
		&model.RuntimePolicySetID,
		&runtimePolicyOverrideBytes,
		&model.Enabled,
		&model.CreatedAt,
		&model.UpdatedAt,
	)
	if err != nil {
		return PlatformModel{}, err
	}
	model.CapabilityOverride = decodeObject(capabilityOverrideBytes)
	model.Capabilities = decodeObject(capabilitiesBytes)
	model.BillingConfigOverride = decodeObject(billingOverrideBytes)
	model.BillingConfig = decodeObject(billingBytes)
	model.PermissionConfig = decodeObject(permissionBytes)
	model.RetryPolicy = decodeObject(retryPolicyBytes)
	model.RateLimitPolicy = decodeObject(rateLimitPolicyBytes)
	model.RuntimePolicyOverride = decodeObject(runtimePolicyOverrideBytes)
	return model, nil
}
// DeletePlatformModel removes the platform_models row with the given id and,
// in the same transaction, the gateway_access_rules rows of resource_type
// 'platform_model' that point at it. It returns pgx.ErrNoRows when no model
// row was deleted; in that case nothing is committed.
func (s *Store) DeletePlatformModel(ctx context.Context, id string) error {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return err
	}
	// Covers every early return below; see the pgx Tx docs for the
	// behaviour of Rollback after a successful Commit.
	defer tx.Rollback(ctx)

	deleted, err := tx.Exec(ctx, `DELETE FROM platform_models WHERE id = $1::uuid`, id)
	switch {
	case err != nil:
		return err
	case deleted.RowsAffected() == 0:
		return pgx.ErrNoRows
	}

	const deleteRules = `
DELETE FROM gateway_access_rules
WHERE resource_type = 'platform_model' AND resource_id = $1::uuid`
	if _, err = tx.Exec(ctx, deleteRules, id); err != nil {
		return err
	}
	return tx.Commit(ctx)
}
// lookupBaseModel reads one base_model_catalog row through q. A row matches
// when its id equals id, its canonical_model_key equals canonicalKey, or its
// provider_model_name equals modelName; each criterion is skipped when its
// (whitespace-trimmed) argument is empty. If several rows match, an id match
// is preferred over a canonical-key match, which is preferred over a name
// match.
//
// On any Scan error, including the no-rows case, it returns a zero snapshot
// together with that error unchanged.
func (s *Store) lookupBaseModel(ctx context.Context, q platformModelQuerier, id string, canonicalKey string, modelName string) (modelCatalogSnapshot, error) {
	const query = `
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy,
COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override
FROM base_model_catalog
WHERE ($1 <> '' AND id = NULLIF($1, '')::uuid)
OR ($2 <> '' AND canonical_model_key = $2)
OR ($3 <> '' AND provider_model_name = $3)
ORDER BY CASE WHEN id::text = $1 THEN 0 WHEN canonical_model_key = $2 THEN 1 ELSE 2 END
LIMIT 1`

	var (
		snapshot           modelCatalogSnapshot
		rawCapabilities    []byte
		rawBilling         []byte
		rawRateLimit       []byte
		rawRuntimeOverride []byte
	)
	row := q.QueryRow(ctx, query, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSpace(modelName))
	if err := row.Scan(
		&snapshot.ID,
		&snapshot.ProviderKey,
		&snapshot.CanonicalModelKey,
		&snapshot.ProviderModelName,
		&snapshot.ModelType,
		&snapshot.DisplayName,
		&rawCapabilities,
		&rawBilling,
		&rawRateLimit,
		&snapshot.RuntimePolicySetID,
		&rawRuntimeOverride,
	); err != nil {
		// Never hand back a partially scanned snapshot.
		return modelCatalogSnapshot{}, err
	}

	snapshot.Capabilities = decodeObject(rawCapabilities)
	snapshot.BaseBillingConfig = decodeObject(rawBilling)
	snapshot.DefaultRateLimitPolicy = decodeObject(rawRateLimit)
	snapshot.RuntimePolicyOverride = decodeObject(rawRuntimeOverride)
	return snapshot, nil
}
// normalizePlatformModelAlias derives the alias stored for a platform model.
//
// The trimmed input alias is used when non-empty; otherwise the alias comes
// from firstNonEmpty over the base model's provider model name, display name
// and canonical key (presumably the first non-empty of those; firstNonEmpty
// is defined elsewhere). A leading "<ProviderKey>:" is then removed, and if
// the result equals the base's canonical key, stripAliasPrefix is applied to
// it. The returned alias is whitespace-trimmed and may be empty.
func normalizePlatformModelAlias(alias string, base modelCatalogSnapshot) string {
	normalized := strings.TrimSpace(alias)
	if normalized == "" {
		normalized = firstNonEmpty(base.ProviderModelName, base.DisplayName, base.CanonicalModelKey)
	}
	if providerKey := base.ProviderKey; providerKey != "" {
		normalized = strings.TrimPrefix(normalized, providerKey+":")
	}
	if normalized == base.CanonicalModelKey {
		normalized = stripAliasPrefix(normalized)
	}
	return strings.TrimSpace(normalized)
}
// stripAliasPrefix drops everything up to and including the first ':' of
// alias, e.g. "vendor:model" becomes "model". The alias is returned unchanged
// when it has no ':' or when either side of the first ':' is empty.
func stripAliasPrefix(alias string) string {
	colon := strings.IndexByte(alias, ':')
	if colon <= 0 || colon == len(alias)-1 {
		// No separator, empty prefix, or empty remainder.
		return alias
	}
	return alias[colon+1:]
}
// mergeObjects returns a new map holding the entries of base overlaid with
// the entries of override (override wins on duplicate keys). The merge is
// shallow and neither argument is modified. It returns nil when both inputs
// are empty.
func mergeObjects(base map[string]any, override map[string]any) map[string]any {
	total := len(base) + len(override)
	if total == 0 {
		return nil
	}
	merged := make(map[string]any, total)
	// Later layers overwrite earlier ones.
	for _, layer := range []map[string]any{base, override} {
		for k, v := range layer {
			merged[k] = v
		}
	}
	return merged
}
// emptyObjectIfNil returns value itself when it is non-nil and a fresh empty
// map otherwise, so that JSON encoding yields "{}" instead of "null".
func emptyObjectIfNil(value map[string]any) map[string]any {
	if value != nil {
		return value
	}
	return make(map[string]any)
}