package store import ( "context" "encoding/json" "strings" "github.com/jackc/pgx/v5" ) const baseModelColumns = ` id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name, capabilities, base_billing_config, default_rate_limit_policy, COALESCE(pricing_rule_set_id::text, ''), COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override, metadata, catalog_type, COALESCE(default_snapshot, '{}'::jsonb), COALESCE(customized_at::text, ''), pricing_version, status, created_at, updated_at` type BaseModelInput struct { ProviderKey string `json:"providerKey"` CanonicalModelKey string `json:"canonicalModelKey"` ProviderModelName string `json:"providerModelName"` ModelType StringList `json:"modelType"` ModelAlias string `json:"modelAlias"` DisplayName string `json:"displayName"` Capabilities map[string]any `json:"capabilities"` BaseBillingConfig map[string]any `json:"baseBillingConfig"` DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy"` PricingRuleSetID string `json:"pricingRuleSetId"` RuntimePolicySetID string `json:"runtimePolicySetId"` RuntimePolicyOverride map[string]any `json:"runtimePolicyOverride"` Metadata map[string]any `json:"metadata"` CatalogType string `json:"catalogType"` DefaultSnapshot map[string]any `json:"defaultSnapshot"` PricingVersion int `json:"pricingVersion"` Status string `json:"status"` } type baseModelScanner interface { Scan(dest ...any) error } type StringList []string func (list *StringList) UnmarshalJSON(data []byte) error { var values []string if err := json.Unmarshal(data, &values); err == nil { *list = uniqueStringList(values) return nil } var value string if err := json.Unmarshal(data, &value); err != nil { return err } *list = uniqueStringList([]string{value}) return nil } func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) { rows, err := s.pool.Query(ctx, ` SELECT `+baseModelColumns+` FROM base_model_catalog ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`) if err != nil { return nil, err } defer rows.Close() items := make([]BaseModel, 0) for rows.Next() { item, err := scanBaseModel(rows) if err != nil { return nil, err } items = append(items, item) } return items, rows.Err() } func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (BaseModel, error) { input = normalizeBaseModelInput(input) capabilities, _ := json.Marshal(emptyObjectIfNil(input.Capabilities)) billingConfig, _ := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig)) rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy)) runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride)) metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata)) defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot)) modelType := primaryString(input.ModelType, "text_generate") return scanBaseModel(s.pool.QueryRow(ctx, ` INSERT INTO base_model_catalog ( provider_id, provider_key, canonical_model_key, provider_model_name, model_type, display_name, capabilities, base_billing_config, default_rate_limit_policy, pricing_rule_set_id, runtime_policy_set_id, runtime_policy_override, metadata, catalog_type, default_snapshot, pricing_version, status ) VALUES ( (SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1), $1, $2, $3, $4, $5, $6, $7, $8, COALESCE(NULLIF($9, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)), COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)), $11, $12, NULLIF($13, ''), NULLIF($14::jsonb, '{}'::jsonb), $15, $16 ) RETURNING `+baseModelColumns, input.ProviderKey, input.CanonicalModelKey, input.ProviderModelName, modelType, input.ModelAlias, capabilities, billingConfig, rateLimitPolicy, input.PricingRuleSetID, input.RuntimePolicySetID, runtimePolicyOverride, metadata, input.CatalogType, string(defaultSnapshot), input.PricingVersion, input.Status, )) } func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelInput) (BaseModel, error) { input = normalizeBaseModelInput(input) capabilities, _ := json.Marshal(emptyObjectIfNil(input.Capabilities)) billingConfig, _ := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig)) rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy)) runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride)) metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata)) defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot)) modelType := primaryString(input.ModelType, "text_generate") return scanBaseModel(s.pool.QueryRow(ctx, ` UPDATE base_model_catalog SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $2 OR provider_code = $2 LIMIT 1), provider_key = $2, canonical_model_key = $3, provider_model_name = $4, model_type = $5, display_name = $6, capabilities = $7, base_billing_config = $8, default_rate_limit_policy = $9, pricing_rule_set_id = COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)), runtime_policy_set_id = COALESCE(NULLIF($11, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)), runtime_policy_override = $12, metadata = $13, catalog_type = NULLIF($14, ''), default_snapshot = COALESCE(NULLIF($15::jsonb, '{}'::jsonb), default_snapshot), customized_at = CASE WHEN NULLIF($14, '') = 'system' THEN now() ELSE NULL END, pricing_version = $16, status = $17, updated_at = now() WHERE id = $1::uuid RETURNING `+baseModelColumns, id, input.ProviderKey, input.CanonicalModelKey, input.ProviderModelName, modelType, input.ModelAlias, capabilities, billingConfig, rateLimitPolicy, input.PricingRuleSetID, input.RuntimePolicySetID, runtimePolicyOverride, metadata, input.CatalogType, string(defaultSnapshot), input.PricingVersion, input.Status, )) } func (s *Store) ResetBaseModelToDefault(ctx context.Context, id string) (BaseModel, error) { var catalogType string var snapshotBytes []byte if err := s.pool.QueryRow(ctx, ` SELECT catalog_type, COALESCE(default_snapshot, '{}'::jsonb) FROM base_model_catalog WHERE id = $1::uuid`, id).Scan(&catalogType, &snapshotBytes); err != nil { return BaseModel{}, err } if catalogType != "system" { return BaseModel{}, ErrProtectedDefault } snapshot := decodeObject(snapshotBytes) if len(snapshot) == 0 { return BaseModel{}, ErrProtectedDefault } return scanBaseModel(s.pool.QueryRow(ctx, ` UPDATE base_model_catalog SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = COALESCE($2::text, provider_key) OR provider_code = COALESCE($2::text, provider_key) LIMIT 1), provider_key = COALESCE($2::text, provider_key), canonical_model_key = COALESCE($3::text, canonical_model_key), provider_model_name = COALESCE($4::text, provider_model_name), model_type = COALESCE($5::text, model_type), display_name = COALESCE($6::text, display_name), capabilities = COALESCE($7::jsonb, capabilities), base_billing_config = COALESCE($8::jsonb, base_billing_config), default_rate_limit_policy = COALESCE($9::jsonb, default_rate_limit_policy), pricing_rule_set_id = COALESCE(NULLIF($10::text, '')::uuid, pricing_rule_set_id), runtime_policy_set_id = COALESCE(NULLIF($11::text, '')::uuid, runtime_policy_set_id), runtime_policy_override = COALESCE($12::jsonb, runtime_policy_override), metadata = COALESCE($13::jsonb, metadata), pricing_version = COALESCE($14::integer, pricing_version), status = COALESCE($15::text, status), customized_at = NULL, updated_at = now() WHERE id = $1::uuid RETURNING `+baseModelColumns, id, stringFromSnapshot(snapshot, "providerKey"), stringFromSnapshot(snapshot, "canonicalModelKey"), stringFromSnapshot(snapshot, "providerModelName"), stringFromSnapshot(snapshot, "modelType"), stringFromSnapshot(snapshot, "modelAlias", "displayName"), jsonFromSnapshot(snapshot, "capabilities"), jsonFromSnapshot(snapshot, "baseBillingConfig"), jsonFromSnapshot(snapshot, "defaultRateLimitPolicy"), stringFromSnapshot(snapshot, "pricingRuleSetId"), stringFromSnapshot(snapshot, "runtimePolicySetId"), jsonFromSnapshot(snapshot, "runtimePolicyOverride"), jsonFromSnapshot(snapshot, "metadata"), intFromSnapshot(snapshot, "pricingVersion"), stringFromSnapshot(snapshot, "status"), )) } func (s *Store) ResetAllBaseModelsToDefault(ctx context.Context) ([]BaseModel, error) { rows, err := s.pool.Query(ctx, ` UPDATE base_model_catalog SET provider_id = ( SELECT id FROM model_catalog_providers WHERE provider_key = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key) OR provider_code = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key) LIMIT 1 ), provider_key = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key), canonical_model_key = COALESCE(NULLIF(default_snapshot->>'canonicalModelKey', ''), canonical_model_key), provider_model_name = COALESCE(NULLIF(default_snapshot->>'providerModelName', ''), provider_model_name), model_type = COALESCE(NULLIF(CASE WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'->>0 ELSE default_snapshot->>'modelType' END, ''), model_type), display_name = COALESCE(NULLIF(COALESCE(default_snapshot->>'modelAlias', default_snapshot->>'displayName'), ''), display_name), capabilities = COALESCE(default_snapshot->'capabilities', capabilities), base_billing_config = COALESCE(default_snapshot->'baseBillingConfig', base_billing_config), default_rate_limit_policy = COALESCE(default_snapshot->'defaultRateLimitPolicy', default_rate_limit_policy), pricing_rule_set_id = COALESCE(NULLIF(default_snapshot->>'pricingRuleSetId', '')::uuid, pricing_rule_set_id), runtime_policy_set_id = COALESCE(NULLIF(default_snapshot->>'runtimePolicySetId', '')::uuid, runtime_policy_set_id), runtime_policy_override = COALESCE(default_snapshot->'runtimePolicyOverride', runtime_policy_override), metadata = COALESCE(default_snapshot->'metadata', metadata), pricing_version = COALESCE(NULLIF(default_snapshot->>'pricingVersion', '')::integer, pricing_version), status = COALESCE(NULLIF(default_snapshot->>'status', ''), status), customized_at = NULL, updated_at = now() WHERE catalog_type = 'system' AND COALESCE(default_snapshot, '{}'::jsonb) <> '{}'::jsonb RETURNING `+baseModelColumns) if err != nil { return nil, err } defer rows.Close() return scanBaseModelRows(rows) } func (s *Store) DeleteBaseModel(ctx context.Context, id string) error { result, err := s.pool.Exec(ctx, `DELETE FROM base_model_catalog WHERE id = $1::uuid`, id) if err != nil { return err } if result.RowsAffected() == 0 { return pgx.ErrNoRows } return nil } func scanBaseModelRows(rows pgx.Rows) ([]BaseModel, error) { items := make([]BaseModel, 0) for rows.Next() { item, err := scanBaseModel(rows) if err != nil { return nil, err } items = append(items, item) } return items, rows.Err() } func scanBaseModel(scanner baseModelScanner) (BaseModel, error) { var item BaseModel var modelType string var modelAlias string var capabilities []byte var billingConfig []byte var rateLimitPolicy []byte var runtimePolicyOverride []byte var metadata []byte var defaultSnapshot []byte if err := scanner.Scan( &item.ID, &item.ProviderKey, &item.CanonicalModelKey, &item.ProviderModelName, &modelType, &modelAlias, &capabilities, &billingConfig, &rateLimitPolicy, &item.PricingRuleSetID, &item.RuntimePolicySetID, &runtimePolicyOverride, &metadata, &item.CatalogType, &defaultSnapshot, &item.CustomizedAt, &item.PricingVersion, &item.Status, &item.CreatedAt, &item.UpdatedAt, ); err != nil { return BaseModel{}, err } item.Capabilities = decodeObject(capabilities) item.BaseBillingConfig = decodeObject(billingConfig) item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy) item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride) item.Metadata = decodeObject(metadata) item.DefaultSnapshot = decodeObject(defaultSnapshot) item.ModelType = baseModelTypes(item.Capabilities, item.Metadata, modelType) item.ModelAlias = modelAlias item.DisplayName = modelAlias return item, nil } func normalizeBaseModelInput(input BaseModelInput) BaseModelInput { input.ProviderKey = strings.TrimSpace(input.ProviderKey) input.CanonicalModelKey = strings.TrimSpace(input.CanonicalModelKey) input.ProviderModelName = strings.TrimSpace(input.ProviderModelName) input.ModelType = uniqueStringList(input.ModelType) input.ModelAlias = strings.TrimSpace(input.ModelAlias) input.DisplayName = strings.TrimSpace(input.DisplayName) input.PricingRuleSetID = strings.TrimSpace(input.PricingRuleSetID) input.RuntimePolicySetID = strings.TrimSpace(input.RuntimePolicySetID) input.CatalogType = strings.TrimSpace(input.CatalogType) input.Status = strings.TrimSpace(input.Status) if input.CanonicalModelKey == "" && input.ProviderKey != "" && input.ProviderModelName != "" { input.CanonicalModelKey = input.ProviderKey + ":" + input.ProviderModelName } if input.ModelAlias == "" { input.ModelAlias = input.DisplayName } if input.ModelAlias == "" { input.ModelAlias = input.ProviderModelName } if len(input.ModelType) == 0 { input.ModelType = StringList{"text_generate"} } if input.CatalogType == "" { input.CatalogType = "custom" } if input.PricingVersion <= 0 { input.PricingVersion = 1 } if input.Status == "" { input.Status = "active" } return input } func stringFromSnapshot(snapshot map[string]any, keys ...string) any { for _, key := range keys { value, ok := snapshot[key] if !ok { continue } switch typed := value.(type) { case string: if strings.TrimSpace(typed) != "" { return typed } case []any: for _, item := range typed { if text, ok := item.(string); ok && strings.TrimSpace(text) != "" { return strings.TrimSpace(text) } } case []string: if primary := primaryString(typed, ""); primary != "" { return primary } } } return nil } func intFromSnapshot(snapshot map[string]any, key string) any { switch value := snapshot[key].(type) { case float64: return int(value) case int: return value default: return nil } } func jsonFromSnapshot(snapshot map[string]any, key string) any { value, ok := snapshot[key] if !ok || value == nil { return nil } raw, err := json.Marshal(value) if err != nil { return nil } return string(raw) } func baseModelTypes(capabilities map[string]any, metadata map[string]any, fallback string) StringList { values := make([]string, 0) values = append(values, stringListFromAny(capabilities["originalTypes"])...) values = append(values, stringListFromAny(metadata["originalTypes"])...) if fallback != "" { values = append(values, fallback) } return uniqueStringList(values) } func stringListFromAny(value any) []string { switch typed := value.(type) { case []string: return typed case []any: values := make([]string, 0, len(typed)) for _, item := range typed { if text, ok := item.(string); ok { values = append(values, text) } } return values default: return nil } } func uniqueStringList(values []string) StringList { out := make([]string, 0, len(values)) seen := map[string]bool{} for _, value := range values { value = strings.TrimSpace(value) if value == "" || seen[value] { continue } seen[value] = true out = append(out, value) } return out } func primaryString(values []string, fallback string) string { for _, value := range values { if value = strings.TrimSpace(value); value != "" { return value } } return fallback }