easyai-ai-gateway/apps/api/internal/store/base_models.go

472 lines
16 KiB
Go

package store
import (
"context"
"encoding/json"
"strings"
"github.com/jackc/pgx/v5"
)
// baseModelColumns is the column list shared by every SELECT and RETURNING
// clause in this file that reads base_model_catalog rows.
// The order of the columns must stay in sync with the Scan destinations in
// scanBaseModel. Columns wrapped in COALESCE fall back to an empty string or
// an empty JSON object instead of NULL.
const baseModelColumns = `
id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy, COALESCE(pricing_rule_set_id::text, ''),
COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override, metadata,
catalog_type, COALESCE(default_snapshot, '{}'::jsonb), COALESCE(customized_at::text, ''),
pricing_version, status, created_at, updated_at`
// BaseModelInput is the payload accepted by CreateBaseModel and UpdateBaseModel.
// Both callers pass it through normalizeBaseModelInput before writing, which
// trims the string fields and fills in the defaults mentioned below.
type BaseModelInput struct {
// ProviderKey is written to provider_key and is also used to resolve
// provider_id (matched against provider_key or provider_code).
ProviderKey string `json:"providerKey"`
// CanonicalModelKey defaults to "<ProviderKey>:<ProviderModelName>" when it is
// empty and both of those fields are set.
CanonicalModelKey string `json:"canonicalModelKey"`
ProviderModelName string `json:"providerModelName"`
// ModelType may hold several types; only the first non-blank entry is written
// to the model_type column. Defaults to "text_generate".
ModelType StringList `json:"modelType"`
// ModelAlias is the value written to the display_name column. When empty it
// falls back to DisplayName and then to ProviderModelName.
ModelAlias string `json:"modelAlias"`
// DisplayName is only used as a fallback for ModelAlias.
DisplayName string `json:"displayName"`
// The map fields below are stored as JSON; nil maps are passed through
// emptyObjectIfNil before being marshalled.
Capabilities map[string]any `json:"capabilities"`
BaseBillingConfig map[string]any `json:"baseBillingConfig"`
DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy"`
// PricingRuleSetID is a uuid string; when empty the rule set with key
// 'default-multimodal-v1' is used.
PricingRuleSetID string `json:"pricingRuleSetId"`
// RuntimePolicySetID is a uuid string; when empty the policy set with key
// 'default-runtime-v1' is used.
RuntimePolicySetID string `json:"runtimePolicySetId"`
RuntimePolicyOverride map[string]any `json:"runtimePolicyOverride"`
Metadata map[string]any `json:"metadata"`
// CatalogType defaults to "custom". The reset operations only act on rows
// whose catalog type is "system".
CatalogType string `json:"catalogType"`
// DefaultSnapshot is the state the reset operations restore. An empty
// snapshot is stored as NULL on create and leaves the stored snapshot
// untouched on update.
DefaultSnapshot map[string]any `json:"defaultSnapshot"`
// PricingVersion defaults to 1 when it is zero or negative.
PricingVersion int `json:"pricingVersion"`
// Status defaults to "active".
Status string `json:"status"`
}
// baseModelScanner is the minimal Scan capability scanBaseModel needs, so the
// same function can read both a single QueryRow result and the current row of
// a pgx.Rows cursor.
type baseModelScanner interface {
Scan(dest ...any) error
}
// StringList is a list of strings whose JSON form may be either an array of
// strings or a single string (see UnmarshalJSON).
type StringList []string
// UnmarshalJSON decodes either a JSON array of strings or a single JSON string
// into the list. Whichever form is supplied, the values are passed through
// uniqueStringList (trimmed, blanks dropped, duplicates removed). Any other
// JSON shape yields the error produced by decoding it as a string.
func (list *StringList) UnmarshalJSON(data []byte) error {
	// Preferred form: an array of strings.
	var many []string
	if json.Unmarshal(data, &many) == nil {
		*list = uniqueStringList(many)
		return nil
	}

	// Fallback form: one bare string, treated as a single-element list.
	var single string
	err := json.Unmarshal(data, &single)
	if err != nil {
		return err
	}
	*list = uniqueStringList([]string{single})
	return nil
}
// ListBaseModels returns every row of base_model_catalog, ordered by provider
// key, then model type, then canonical model key.
//
// The result is an empty (non-nil) slice when the table has no rows. Query,
// scan and iteration errors are returned unchanged.
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT `+baseModelColumns+`
		FROM base_model_catalog
		ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Reuse the shared row-draining helper instead of repeating its loop here.
	return scanBaseModelRows(rows)
}
// CreateBaseModel inserts a new base_model_catalog row and returns it as stored.
//
// The input is normalized first (see normalizeBaseModelInput). provider_id is
// looked up from model_catalog_providers by provider key or provider code.
// Empty pricing-rule-set / runtime-policy-set ids fall back to the
// 'default-multimodal-v1' and 'default-runtime-v1' sets, and an empty default
// snapshot is stored as NULL. Only the first model type is persisted in
// model_type, and ModelAlias is what ends up in display_name.
//
// An error is returned when any JSON field cannot be encoded (the database is
// not touched in that case) or when the insert itself fails.
func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (BaseModel, error) {
	input = normalizeBaseModelInput(input)

	// Encode all JSON columns up front, remembering the first failure so it is
	// reported instead of silently sending nil bytes to the database.
	var encodeErr error
	encode := func(object map[string]any) []byte {
		raw, err := json.Marshal(emptyObjectIfNil(object))
		if err != nil && encodeErr == nil {
			encodeErr = err
		}
		return raw
	}
	capabilities := encode(input.Capabilities)
	billingConfig := encode(input.BaseBillingConfig)
	rateLimitPolicy := encode(input.DefaultRateLimitPolicy)
	runtimePolicyOverride := encode(input.RuntimePolicyOverride)
	metadata := encode(input.Metadata)
	defaultSnapshot := encode(input.DefaultSnapshot)
	if encodeErr != nil {
		return BaseModel{}, encodeErr
	}

	modelType := primaryString(input.ModelType, "text_generate")
	return scanBaseModel(s.pool.QueryRow(ctx, `
		INSERT INTO base_model_catalog (
			provider_id, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
			capabilities, base_billing_config, default_rate_limit_policy, pricing_rule_set_id, runtime_policy_set_id, runtime_policy_override,
			metadata, catalog_type, default_snapshot, pricing_version, status
		)
		VALUES (
			(SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1),
			$1, $2, $3, $4, $5, $6, $7, $8,
			COALESCE(NULLIF($9, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)),
			COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)),
			$11, $12, NULLIF($13, ''), NULLIF($14::jsonb, '{}'::jsonb), $15, $16
		)
		RETURNING `+baseModelColumns,
		input.ProviderKey,
		input.CanonicalModelKey,
		input.ProviderModelName,
		modelType,
		input.ModelAlias,
		capabilities,
		billingConfig,
		rateLimitPolicy,
		input.PricingRuleSetID,
		input.RuntimePolicySetID,
		runtimePolicyOverride,
		metadata,
		input.CatalogType,
		string(defaultSnapshot), // sent as text; the statement casts it to jsonb
		input.PricingVersion,
		input.Status,
	))
}
// UpdateBaseModel overwrites the base_model_catalog row identified by id (a
// uuid string) with the normalized input and returns the row as stored.
//
// Notable column rules, all taken from the statement below:
//   - provider_id is re-resolved from the provider key / provider code;
//   - empty pricing-rule-set / runtime-policy-set ids fall back to the
//     'default-multimodal-v1' and 'default-runtime-v1' sets;
//   - an empty default snapshot leaves the stored snapshot unchanged;
//   - customized_at becomes now() for catalog type 'system' and NULL otherwise.
//
// An error is returned when any JSON field cannot be encoded (the database is
// not touched in that case), when the update fails, or when scanning the
// returned row fails (for example because no row matched id).
func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelInput) (BaseModel, error) {
	input = normalizeBaseModelInput(input)

	// Encode all JSON columns up front, remembering the first failure so it is
	// reported instead of silently sending nil bytes to the database.
	var encodeErr error
	encode := func(object map[string]any) []byte {
		raw, err := json.Marshal(emptyObjectIfNil(object))
		if err != nil && encodeErr == nil {
			encodeErr = err
		}
		return raw
	}
	capabilities := encode(input.Capabilities)
	billingConfig := encode(input.BaseBillingConfig)
	rateLimitPolicy := encode(input.DefaultRateLimitPolicy)
	runtimePolicyOverride := encode(input.RuntimePolicyOverride)
	metadata := encode(input.Metadata)
	defaultSnapshot := encode(input.DefaultSnapshot)
	if encodeErr != nil {
		return BaseModel{}, encodeErr
	}

	modelType := primaryString(input.ModelType, "text_generate")
	return scanBaseModel(s.pool.QueryRow(ctx, `
		UPDATE base_model_catalog
		SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $2 OR provider_code = $2 LIMIT 1),
			provider_key = $2,
			canonical_model_key = $3,
			provider_model_name = $4,
			model_type = $5,
			display_name = $6,
			capabilities = $7,
			base_billing_config = $8,
			default_rate_limit_policy = $9,
			pricing_rule_set_id = COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)),
			runtime_policy_set_id = COALESCE(NULLIF($11, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)),
			runtime_policy_override = $12,
			metadata = $13,
			catalog_type = NULLIF($14, ''),
			default_snapshot = COALESCE(NULLIF($15::jsonb, '{}'::jsonb), default_snapshot),
			customized_at = CASE WHEN NULLIF($14, '') = 'system' THEN now() ELSE NULL END,
			pricing_version = $16,
			status = $17,
			updated_at = now()
		WHERE id = $1::uuid
		RETURNING `+baseModelColumns,
		id,
		input.ProviderKey,
		input.CanonicalModelKey,
		input.ProviderModelName,
		modelType,
		input.ModelAlias,
		capabilities,
		billingConfig,
		rateLimitPolicy,
		input.PricingRuleSetID,
		input.RuntimePolicySetID,
		runtimePolicyOverride,
		metadata,
		input.CatalogType,
		string(defaultSnapshot), // sent as text; the statement casts it to jsonb
		input.PricingVersion,
		input.Status,
	))
}
// ResetBaseModelToDefault restores one base_model_catalog row to the values
// recorded in its default_snapshot and clears customized_at.
//
// Only rows with catalog type "system" and a non-empty snapshot can be reset;
// otherwise ErrProtectedDefault is returned. A snapshot field that is missing,
// blank or of an unexpected type is sent as NULL, so COALESCE keeps the
// column's current value. Lookup, update and scan errors are returned
// unchanged.
//
// NOTE: the snapshot is read and applied in two separate statements, so a
// concurrent change between them is not detected.
func (s *Store) ResetBaseModelToDefault(ctx context.Context, id string) (BaseModel, error) {
	var catalogType string
	var snapshotBytes []byte
	if err := s.pool.QueryRow(ctx, `
		SELECT catalog_type, COALESCE(default_snapshot, '{}'::jsonb)
		FROM base_model_catalog
		WHERE id = $1::uuid`, id).Scan(&catalogType, &snapshotBytes); err != nil {
		return BaseModel{}, err
	}
	if catalogType != "system" {
		return BaseModel{}, ErrProtectedDefault
	}
	snapshot := decodeObject(snapshotBytes)
	if len(snapshot) == 0 {
		return BaseModel{}, ErrProtectedDefault
	}

	// The provider sub-select must qualify the outer provider_key explicitly:
	// model_catalog_providers has its own provider_key column, so an
	// unqualified reference would bind to the inner table and, when the
	// snapshot has no providerKey, match an arbitrary provider.
	return scanBaseModel(s.pool.QueryRow(ctx, `
		UPDATE base_model_catalog
		SET provider_id = (
				SELECT p.id
				FROM model_catalog_providers p
				WHERE p.provider_key = COALESCE($2::text, base_model_catalog.provider_key)
				   OR p.provider_code = COALESCE($2::text, base_model_catalog.provider_key)
				LIMIT 1
			),
			provider_key = COALESCE($2::text, provider_key),
			canonical_model_key = COALESCE($3::text, canonical_model_key),
			provider_model_name = COALESCE($4::text, provider_model_name),
			model_type = COALESCE($5::text, model_type),
			display_name = COALESCE($6::text, display_name),
			capabilities = COALESCE($7::jsonb, capabilities),
			base_billing_config = COALESCE($8::jsonb, base_billing_config),
			default_rate_limit_policy = COALESCE($9::jsonb, default_rate_limit_policy),
			pricing_rule_set_id = COALESCE(NULLIF($10::text, '')::uuid, pricing_rule_set_id),
			runtime_policy_set_id = COALESCE(NULLIF($11::text, '')::uuid, runtime_policy_set_id),
			runtime_policy_override = COALESCE($12::jsonb, runtime_policy_override),
			metadata = COALESCE($13::jsonb, metadata),
			pricing_version = COALESCE($14::integer, pricing_version),
			status = COALESCE($15::text, status),
			customized_at = NULL,
			updated_at = now()
		WHERE id = $1::uuid
		RETURNING `+baseModelColumns,
		id,
		stringFromSnapshot(snapshot, "providerKey"),
		stringFromSnapshot(snapshot, "canonicalModelKey"),
		stringFromSnapshot(snapshot, "providerModelName"),
		stringFromSnapshot(snapshot, "modelType"),
		stringFromSnapshot(snapshot, "modelAlias", "displayName"),
		jsonFromSnapshot(snapshot, "capabilities"),
		jsonFromSnapshot(snapshot, "baseBillingConfig"),
		jsonFromSnapshot(snapshot, "defaultRateLimitPolicy"),
		stringFromSnapshot(snapshot, "pricingRuleSetId"),
		stringFromSnapshot(snapshot, "runtimePolicySetId"),
		jsonFromSnapshot(snapshot, "runtimePolicyOverride"),
		jsonFromSnapshot(snapshot, "metadata"),
		intFromSnapshot(snapshot, "pricingVersion"),
		stringFromSnapshot(snapshot, "status"),
	))
}
// ResetAllBaseModelsToDefault restores, in a single statement, every
// base_model_catalog row with catalog type 'system' and a non-empty
// default_snapshot to the values stored in that snapshot, clears customized_at
// and returns the updated rows (an empty slice when nothing matched).
//
// A snapshot field that is missing, empty or JSON null leaves the column's
// current value in place, matching the single-row ResetBaseModelToDefault.
func (s *Store) ResetAllBaseModelsToDefault(ctx context.Context) ([]BaseModel, error) {
	// Notes on the statement:
	//   - Inside the provider sub-select the outer columns are qualified with
	//     base_model_catalog: model_catalog_providers has its own provider_key,
	//     so an unqualified reference would bind to the inner table and match an
	//     arbitrary provider whenever the snapshot has no providerKey.
	//   - display_name falls back to displayName when modelAlias is missing OR
	//     empty.
	//   - "->" yields a jsonb null (not SQL NULL) for JSON null values, so those
	//     are mapped to NULL before COALESCE to avoid overwriting the column
	//     with JSON null.
	rows, err := s.pool.Query(ctx, `
		UPDATE base_model_catalog
		SET provider_id = (
				SELECT p.id
				FROM model_catalog_providers p
				WHERE p.provider_key = COALESCE(NULLIF(base_model_catalog.default_snapshot->>'providerKey', ''), base_model_catalog.provider_key)
				   OR p.provider_code = COALESCE(NULLIF(base_model_catalog.default_snapshot->>'providerKey', ''), base_model_catalog.provider_key)
				LIMIT 1
			),
			provider_key = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key),
			canonical_model_key = COALESCE(NULLIF(default_snapshot->>'canonicalModelKey', ''), canonical_model_key),
			provider_model_name = COALESCE(NULLIF(default_snapshot->>'providerModelName', ''), provider_model_name),
			model_type = COALESCE(NULLIF(CASE
				WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'->>0
				ELSE default_snapshot->>'modelType'
			END, ''), model_type),
			display_name = COALESCE(NULLIF(default_snapshot->>'modelAlias', ''), NULLIF(default_snapshot->>'displayName', ''), display_name),
			capabilities = COALESCE(NULLIF(default_snapshot->'capabilities', 'null'::jsonb), capabilities),
			base_billing_config = COALESCE(NULLIF(default_snapshot->'baseBillingConfig', 'null'::jsonb), base_billing_config),
			default_rate_limit_policy = COALESCE(NULLIF(default_snapshot->'defaultRateLimitPolicy', 'null'::jsonb), default_rate_limit_policy),
			pricing_rule_set_id = COALESCE(NULLIF(default_snapshot->>'pricingRuleSetId', '')::uuid, pricing_rule_set_id),
			runtime_policy_set_id = COALESCE(NULLIF(default_snapshot->>'runtimePolicySetId', '')::uuid, runtime_policy_set_id),
			runtime_policy_override = COALESCE(NULLIF(default_snapshot->'runtimePolicyOverride', 'null'::jsonb), runtime_policy_override),
			metadata = COALESCE(NULLIF(default_snapshot->'metadata', 'null'::jsonb), metadata),
			pricing_version = COALESCE(NULLIF(default_snapshot->>'pricingVersion', '')::integer, pricing_version),
			status = COALESCE(NULLIF(default_snapshot->>'status', ''), status),
			customized_at = NULL,
			updated_at = now()
		WHERE catalog_type = 'system'
		  AND COALESCE(default_snapshot, '{}'::jsonb) <> '{}'::jsonb
		RETURNING `+baseModelColumns)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanBaseModelRows(rows)
}
// DeleteBaseModel removes the base_model_catalog row with the given uuid.
// It returns the execution error if the statement fails, and pgx.ErrNoRows
// when the statement succeeds but no row was deleted.
func (s *Store) DeleteBaseModel(ctx context.Context, id string) error {
	tag, err := s.pool.Exec(ctx, `DELETE FROM base_model_catalog WHERE id = $1::uuid`, id)
	switch {
	case err != nil:
		return err
	case tag.RowsAffected() == 0:
		// Nothing matched the id; report it the same way a missing row is
		// reported by a single-row query.
		return pgx.ErrNoRows
	default:
		return nil
	}
}
// scanBaseModelRows drains rows into a slice of BaseModel values.
// The slice is empty (not nil) when there are no rows. A scan failure aborts
// with a nil slice; an iteration error reported by rows.Err is returned
// together with whatever was collected. The caller remains responsible for
// closing rows.
func scanBaseModelRows(rows pgx.Rows) ([]BaseModel, error) {
	models := []BaseModel{}
	for rows.Next() {
		model, scanErr := scanBaseModel(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		models = append(models, model)
	}
	if err := rows.Err(); err != nil {
		return models, err
	}
	return models, nil
}
// scanBaseModel reads one row laid out as baseModelColumns into a BaseModel.
//
// JSON columns are scanned as raw bytes and decoded with decodeObject.
// ModelType is rebuilt by baseModelTypes from the "originalTypes" entries of
// capabilities and metadata plus the stored model_type, and both ModelAlias and
// DisplayName are set from the display_name column. On a scan error a zero
// BaseModel is returned.
func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
	var (
		model      BaseModel
		storedType string
		alias      string

		rawCapabilities, rawBilling, rawRateLimit []byte
		rawOverride, rawMetadata, rawSnapshot     []byte
	)

	// Destinations are listed in exactly the order of baseModelColumns.
	dest := []any{
		&model.ID,
		&model.ProviderKey,
		&model.CanonicalModelKey,
		&model.ProviderModelName,
		&storedType,
		&alias,
		&rawCapabilities,
		&rawBilling,
		&rawRateLimit,
		&model.PricingRuleSetID,
		&model.RuntimePolicySetID,
		&rawOverride,
		&rawMetadata,
		&model.CatalogType,
		&rawSnapshot,
		&model.CustomizedAt,
		&model.PricingVersion,
		&model.Status,
		&model.CreatedAt,
		&model.UpdatedAt,
	}
	if err := scanner.Scan(dest...); err != nil {
		return BaseModel{}, err
	}

	model.Capabilities = decodeObject(rawCapabilities)
	model.BaseBillingConfig = decodeObject(rawBilling)
	model.DefaultRateLimitPolicy = decodeObject(rawRateLimit)
	model.RuntimePolicyOverride = decodeObject(rawOverride)
	model.Metadata = decodeObject(rawMetadata)
	model.DefaultSnapshot = decodeObject(rawSnapshot)

	// Model types depend on the decoded capabilities and metadata, so they are
	// derived only after those are in place.
	model.ModelType = baseModelTypes(model.Capabilities, model.Metadata, storedType)
	model.ModelAlias, model.DisplayName = alias, alias
	return model, nil
}
// normalizeBaseModelInput returns a cleaned copy of input:
//   - all string fields are trimmed and ModelType is de-duplicated;
//   - CanonicalModelKey defaults to "<ProviderKey>:<ProviderModelName>" when
//     empty and both parts are present;
//   - ModelAlias falls back to DisplayName, then to ProviderModelName;
//   - ModelType defaults to ["text_generate"], CatalogType to "custom",
//     PricingVersion to 1 (when <= 0) and Status to "active".
func normalizeBaseModelInput(input BaseModelInput) BaseModelInput {
	// Trim every free-text field in place on the local copy.
	textFields := []*string{
		&input.ProviderKey,
		&input.CanonicalModelKey,
		&input.ProviderModelName,
		&input.ModelAlias,
		&input.DisplayName,
		&input.PricingRuleSetID,
		&input.RuntimePolicySetID,
		&input.CatalogType,
		&input.Status,
	}
	for _, field := range textFields {
		*field = strings.TrimSpace(*field)
	}
	input.ModelType = uniqueStringList(input.ModelType)

	canDeriveKey := input.ProviderKey != "" && input.ProviderModelName != ""
	if input.CanonicalModelKey == "" && canDeriveKey {
		input.CanonicalModelKey = input.ProviderKey + ":" + input.ProviderModelName
	}

	switch {
	case input.ModelAlias != "":
		// An explicit alias wins.
	case input.DisplayName != "":
		input.ModelAlias = input.DisplayName
	default:
		input.ModelAlias = input.ProviderModelName
	}

	if len(input.ModelType) == 0 {
		input.ModelType = StringList{"text_generate"}
	}
	if input.CatalogType == "" {
		input.CatalogType = "custom"
	}
	if input.PricingVersion <= 0 {
		input.PricingVersion = 1
	}
	if input.Status == "" {
		input.Status = "active"
	}
	return input
}
// stringFromSnapshot returns the first usable string found under keys, tried
// in order, or nil when none of them yields one. The result is typed any so a
// nil can be passed straight through as a query argument.
//
// Per key: a non-blank string value is returned as stored (not trimmed); for a
// []any or []string value the first non-blank string element is returned
// trimmed. Values of any other type are skipped.
func stringFromSnapshot(snapshot map[string]any, keys ...string) any {
	for _, key := range keys {
		raw, present := snapshot[key]
		if !present {
			continue
		}

		if text, isString := raw.(string); isString {
			if strings.TrimSpace(text) != "" {
				return text
			}
			continue
		}

		if items, isList := raw.([]any); isList {
			for _, item := range items {
				text, isString := item.(string)
				if !isString {
					continue
				}
				if trimmed := strings.TrimSpace(text); trimmed != "" {
					return trimmed
				}
			}
			continue
		}

		if list, isStrings := raw.([]string); isStrings {
			if first := primaryString(list, ""); first != "" {
				return first
			}
		}
	}
	return nil
}
// intFromSnapshot returns snapshot[key] as an int, or nil when the key is
// missing or its value cannot be read as a number. The result is typed any so
// a nil can be passed straight through as a query argument.
//
// Accepted value types:
//   - float64 (truncated toward zero), int and int64;
//   - json.Number, read as an integer or, failing that, as a truncated float;
//   - a string holding a base-10 integer (surrounding spaces ignored), so that
//     string-encoded versions behave the same way as in the bulk reset, which
//     casts the snapshot's text value to integer.
func intFromSnapshot(snapshot map[string]any, key string) any {
	switch value := snapshot[key].(type) {
	case float64:
		return int(value)
	case int:
		return value
	case int64:
		return int(value)
	case json.Number:
		if whole, err := value.Int64(); err == nil {
			return int(whole)
		}
		if fractional, err := value.Float64(); err == nil {
			return int(fractional)
		}
		return nil
	case string:
		// json.Number.Int64 is a base-10 integer parse of the underlying text.
		whole, err := json.Number(strings.TrimSpace(value)).Int64()
		if err != nil {
			return nil
		}
		return int(whole)
	default:
		return nil
	}
}
// jsonFromSnapshot returns snapshot[key] encoded as a JSON string, or nil when
// the key is missing, its value is nil, or the value cannot be marshalled. The
// result is typed any so a nil can be passed straight through as a query
// argument.
func jsonFromSnapshot(snapshot map[string]any, key string) any {
	// A missing key and an explicit nil both read back as nil here.
	if value := snapshot[key]; value != nil {
		if encoded, err := json.Marshal(value); err == nil {
			return string(encoded)
		}
	}
	return nil
}
// baseModelTypes assembles the model type list for a row: the "originalTypes"
// entries of capabilities, then those of metadata, then fallback when it is
// non-empty. The combined list is trimmed and de-duplicated by
// uniqueStringList, which keeps first occurrences in order.
func baseModelTypes(capabilities map[string]any, metadata map[string]any, fallback string) StringList {
	fromCapabilities := stringListFromAny(capabilities["originalTypes"])
	fromMetadata := stringListFromAny(metadata["originalTypes"])

	candidates := make([]string, 0, len(fromCapabilities)+len(fromMetadata)+1)
	candidates = append(candidates, fromCapabilities...)
	candidates = append(candidates, fromMetadata...)
	if fallback != "" {
		candidates = append(candidates, fallback)
	}
	return uniqueStringList(candidates)
}
// stringListFromAny converts a loosely typed value into a string slice.
// A []string is returned as is; a []any yields its string elements in order
// (non-string elements are dropped); any other value yields nil.
func stringListFromAny(value any) []string {
	if list, ok := value.([]string); ok {
		return list
	}

	items, ok := value.([]any)
	if !ok {
		return nil
	}
	texts := make([]string, 0, len(items))
	for _, item := range items {
		text, isString := item.(string)
		if isString {
			texts = append(texts, text)
		}
	}
	return texts
}
// uniqueStringList trims every value, drops the ones that end up empty and
// removes duplicates, keeping the first occurrence of each value in its
// original order. The result is never nil.
func uniqueStringList(values []string) StringList {
	unique := make(StringList, 0, len(values))
	seen := make(map[string]struct{}, len(values))
	for _, raw := range values {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			continue
		}
		if _, duplicate := seen[trimmed]; duplicate {
			continue
		}
		seen[trimmed] = struct{}{}
		unique = append(unique, trimmed)
	}
	return unique
}
// primaryString returns the first entry of values that is non-empty after
// trimming (in its trimmed form), or fallback when there is no such entry.
func primaryString(values []string, fallback string) string {
	for i := range values {
		trimmed := strings.TrimSpace(values[i])
		if trimmed == "" {
			continue
		}
		return trimmed
	}
	return fallback
}