package store

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/jackc/pgx/v5"
)

// platformModelQuerier is the single query method the platform-model helpers
// need. CreatePlatformModel passes the connection pool and
// ReplacePlatformModels passes its transaction, so createPlatformModel and
// lookupBaseModel can run either standalone or inside that transaction.
type platformModelQuerier interface {
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}

// modelCatalogSnapshot holds the base_model_catalog columns read by
// lookupBaseModel, with the JSON columns decoded into maps.
type modelCatalogSnapshot struct {
	ID                     string
	ProviderKey            string
	CanonicalModelKey      string
	ProviderModelName      string
	ModelType              string
	DisplayName            string
	Capabilities           map[string]any
	BaseBillingConfig      map[string]any
	DefaultRateLimitPolicy map[string]any
	RuntimePolicySetID     string
	RuntimePolicyOverride  map[string]any
}

// CreatePlatformModel inserts a platform model. If a row with the same
// (platform_id, model_name, model_type) already exists, it is updated and
// re-enabled instead. Fields left empty in input are filled from the matching
// base catalog entry, when one exists.
func (s *Store) CreatePlatformModel(ctx context.Context, input CreatePlatformModelInput) (PlatformModel, error) {
	return s.createPlatformModel(ctx, s.pool, input)
}

// ReplacePlatformModels makes inputs the complete model set of platformID.
//
// Within one transaction it upserts every input, with PlatformID forced to
// platformID. It then deletes each of the platform's models whose id was not
// returned by those upserts; an empty inputs slice deletes all of them. After
// a successful commit it returns ListPlatformModels for the platform.
func (s *Store) ReplacePlatformModels(ctx context.Context, platformID string, inputs []CreatePlatformModelInput) ([]PlatformModel, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return nil, err
	}
	// Covers the early-return error paths; the result is deliberately ignored.
	defer tx.Rollback(ctx)

	keptIDs := make([]string, 0, len(inputs))
	for _, input := range inputs {
		input.PlatformID = platformID
		model, err := s.createPlatformModel(ctx, tx, input)
		if err != nil {
			return nil, err
		}
		keptIDs = append(keptIDs, model.ID)
	}

	// The empty case gets its own statement, so "delete everything" does not
	// depend on how an empty array parameter is encoded.
	if len(keptIDs) == 0 {
		_, err = tx.Exec(ctx, `DELETE FROM platform_models WHERE platform_id = $1::uuid`, platformID)
	} else {
		_, err = tx.Exec(ctx, `
			DELETE FROM platform_models
			WHERE platform_id = $1::uuid
			  AND NOT (id::text = ANY($2::text[]))`, platformID, keptIDs)
	}
	if err != nil {
		return nil, err
	}

	if err := tx.Commit(ctx); err != nil {
		return nil, err
	}
	return s.ListPlatformModels(ctx, platformID)
}

// createPlatformModel upserts one platform model through q.
//
// It looks up a base catalog entry with lookupBaseModel, trying BaseModelID,
// then CanonicalModelKey, then ModelName. A lookup error for which IsNotFound
// reports true is tolerated: the row is then stored with a NULL base_model_id,
// and only input's own values are used.
//
// Fields left empty in input inherit from that entry:
//   - ModelType and ModelName come from the entry.
//   - DisplayName falls back to the entry, then to ModelName.
//   - RateLimitPolicy, RuntimePolicySetID and RuntimePolicyOverride come from
//     the entry.
//
// A non-empty Capabilities or BillingConfig is stored as given. Otherwise the
// stored value is the base value merged with the matching *Override map.
// PricingMode defaults to "inherit_discount". The returned model is the row
// as read back via RETURNING.
func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, input CreatePlatformModelInput) (PlatformModel, error) {
	base, err := s.lookupBaseModel(ctx, q, input.BaseModelID, input.CanonicalModelKey, input.ModelName)
	if err != nil && !IsNotFound(err) {
		return PlatformModel{}, err
	}

	if input.ModelType == "" {
		input.ModelType = base.ModelType
	}
	if input.ModelName == "" {
		input.ModelName = base.ProviderModelName
	}
	if input.DisplayName == "" {
		input.DisplayName = firstNonEmpty(base.DisplayName, input.ModelName)
	}
	input.ModelAlias = normalizePlatformModelAlias(input.ModelAlias, base)
	if input.PricingMode == "" {
		input.PricingMode = "inherit_discount"
	}
	// Trimmed like RuntimePolicySetID below: a whitespace-only ID should mean
	// "unset" (NULLIF -> NULL) rather than fail the ::uuid cast.
	input.PricingRuleSetID = strings.TrimSpace(input.PricingRuleSetID)

	capabilities := input.Capabilities
	if len(capabilities) == 0 {
		capabilities = mergeObjects(base.Capabilities, input.CapabilityOverride)
	}
	billingConfig := input.BillingConfig
	if len(billingConfig) == 0 {
		billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride)
	}
	rateLimitPolicy := input.RateLimitPolicy
	if len(rateLimitPolicy) == 0 {
		rateLimitPolicy = base.DefaultRateLimitPolicy
	}
	runtimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
	if runtimePolicySetID == "" {
		runtimePolicySetID = base.RuntimePolicySetID
	}
	runtimePolicyOverride := input.RuntimePolicyOverride
	if len(runtimePolicyOverride) == 0 {
		runtimePolicyOverride = base.RuntimePolicyOverride
	}

	// Encode every JSON column up front and abort on the first marshal
	// failure. Otherwise an empty string would be sent to a ::jsonb cast.
	var encodeErr error
	encode := func(value map[string]any) string {
		if encodeErr != nil {
			return ""
		}
		raw, err := json.Marshal(emptyObjectIfNil(value))
		if err != nil {
			encodeErr = err
			return ""
		}
		return string(raw)
	}
	capabilityOverrideJSON := encode(input.CapabilityOverride)
	capabilitiesJSON := encode(capabilities)
	billingOverrideJSON := encode(input.BillingConfigOverride)
	billingJSON := encode(billingConfig)
	permissionJSON := encode(input.PermissionConfig)
	retryJSON := encode(input.RetryPolicy)
	rateLimitJSON := encode(rateLimitPolicy)
	runtimePolicyOverrideJSON := encode(runtimePolicyOverride)
	if encodeErr != nil {
		return PlatformModel{}, fmt.Errorf("encode platform model json: %w", encodeErr)
	}

	// A non-positive discount factor is stored as NULL.
	discount := any(nil)
	if input.DiscountFactor > 0 {
		discount = input.DiscountFactor
	}
	// No catalog match: base_model_id is stored as NULL.
	baseID := any(nil)
	if base.ID != "" {
		baseID = base.ID
	}

	var (
		model                      PlatformModel
		capabilityOverrideBytes    []byte
		capabilitiesBytes          []byte
		billingOverrideBytes       []byte
		billingBytes               []byte
		permissionBytes            []byte
		retryPolicyBytes           []byte
		rateLimitPolicyBytes       []byte
		runtimePolicyOverrideBytes []byte
	)
	err = q.QueryRow(ctx, `
		INSERT INTO platform_models (
			platform_id, base_model_id, model_name, model_alias, model_type, display_name,
			capability_override, capabilities, pricing_mode, discount_factor, pricing_rule_set_id,
			billing_config_override, billing_config, permission_config, retry_policy,
			rate_limit_policy, runtime_policy_set_id, runtime_policy_override, enabled
		) VALUES (
			$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6,
			$7::jsonb, $8::jsonb, $9, $10::numeric, NULLIF($11, '')::uuid,
			$12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb,
			$16::jsonb, NULLIF($17, '')::uuid, $18::jsonb, true
		)
		ON CONFLICT (platform_id, model_name, model_type) DO UPDATE SET
			base_model_id = EXCLUDED.base_model_id,
			model_alias = EXCLUDED.model_alias,
			display_name = EXCLUDED.display_name,
			capability_override = EXCLUDED.capability_override,
			capabilities = EXCLUDED.capabilities,
			pricing_mode = EXCLUDED.pricing_mode,
			discount_factor = EXCLUDED.discount_factor,
			pricing_rule_set_id = EXCLUDED.pricing_rule_set_id,
			billing_config_override = EXCLUDED.billing_config_override,
			billing_config = EXCLUDED.billing_config,
			permission_config = EXCLUDED.permission_config,
			retry_policy = EXCLUDED.retry_policy,
			rate_limit_policy = EXCLUDED.rate_limit_policy,
			runtime_policy_set_id = EXCLUDED.runtime_policy_set_id,
			runtime_policy_override = EXCLUDED.runtime_policy_override,
			enabled = true,
			updated_at = now()
		RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_name,
			COALESCE(model_alias, ''), model_type, display_name, capability_override, capabilities,
			pricing_mode, COALESCE(discount_factor, 0)::float8, COALESCE(pricing_rule_set_id::text, ''),
			billing_config_override, billing_config, permission_config, retry_policy, rate_limit_policy,
			COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override, enabled,
			created_at, updated_at`,
		input.PlatformID,
		baseID,
		input.ModelName,
		input.ModelAlias,
		input.ModelType,
		input.DisplayName,
		capabilityOverrideJSON,
		capabilitiesJSON,
		input.PricingMode,
		discount,
		input.PricingRuleSetID,
		billingOverrideJSON,
		billingJSON,
		permissionJSON,
		retryJSON,
		rateLimitJSON,
		runtimePolicySetID,
		runtimePolicyOverrideJSON,
	).Scan(
		&model.ID,
		&model.PlatformID,
		&model.BaseModelID,
		&model.ModelName,
		&model.ModelAlias,
		&model.ModelType,
		&model.DisplayName,
		&capabilityOverrideBytes,
		&capabilitiesBytes,
		&model.PricingMode,
		&model.DiscountFactor,
		&model.PricingRuleSetID,
		&billingOverrideBytes,
		&billingBytes,
		&permissionBytes,
		&retryPolicyBytes,
		&rateLimitPolicyBytes,
		&model.RuntimePolicySetID,
		&runtimePolicyOverrideBytes,
		&model.Enabled,
		&model.CreatedAt,
		&model.UpdatedAt,
	)
	if err != nil {
		return PlatformModel{}, err
	}

	model.CapabilityOverride = decodeObject(capabilityOverrideBytes)
	model.Capabilities = decodeObject(capabilitiesBytes)
	model.BillingConfigOverride = decodeObject(billingOverrideBytes)
	model.BillingConfig = decodeObject(billingBytes)
	model.PermissionConfig = decodeObject(permissionBytes)
	model.RetryPolicy = decodeObject(retryPolicyBytes)
	model.RateLimitPolicy = decodeObject(rateLimitPolicyBytes)
	model.RuntimePolicyOverride = decodeObject(runtimePolicyOverrideBytes)
	return model, nil
}

// DeletePlatformModel deletes the platform model with the given id. It
// returns pgx.ErrNoRows when no row matched.
func (s *Store) DeletePlatformModel(ctx context.Context, id string) error {
	result, err := s.pool.Exec(ctx, `DELETE FROM platform_models WHERE id = $1::uuid`, id)
	if err != nil {
		return err
	}
	if result.RowsAffected() == 0 {
		return pgx.ErrNoRows
	}
	return nil
}

// lookupBaseModel returns one base_model_catalog row matching any of the
// (whitespace-trimmed) keys; empty keys are ignored.
//
// Precedence is an exact id match, then a canonical_model_key match, then a
// provider_model_name match; remaining ties are broken by id. It returns
// pgx.ErrNoRows when nothing matches, including when every key is empty.
func (s *Store) lookupBaseModel(ctx context.Context, q platformModelQuerier, id string, canonicalKey string, modelName string) (modelCatalogSnapshot, error) {
	id = strings.TrimSpace(id)
	canonicalKey = strings.TrimSpace(canonicalKey)
	modelName = strings.TrimSpace(modelName)
	if id == "" && canonicalKey == "" && modelName == "" {
		// The query could not match anything; skip the round trip.
		return modelCatalogSnapshot{}, pgx.ErrNoRows
	}

	var (
		item                  modelCatalogSnapshot
		capabilities          []byte
		billingConfig         []byte
		rateLimitPolicy       []byte
		runtimePolicyOverride []byte
	)
	// The ORDER BY repeats the WHERE predicates verbatim. Ranking the id match
	// by text comparison would demote an id given in non-canonical UUID form,
	// e.g. uppercase.
	err := q.QueryRow(ctx, `
		SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type,
			display_name, capabilities, base_billing_config, default_rate_limit_policy,
			COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override
		FROM base_model_catalog
		WHERE ($1::text <> '' AND id = NULLIF($1::text, '')::uuid)
		   OR ($2::text <> '' AND canonical_model_key = $2)
		   OR ($3::text <> '' AND provider_model_name = $3)
		ORDER BY
			CASE
				WHEN $1 <> '' AND id = NULLIF($1, '')::uuid THEN 0
				WHEN $2 <> '' AND canonical_model_key = $2 THEN 1
				ELSE 2
			END,
			id
		LIMIT 1`,
		id, canonicalKey, modelName,
	).Scan(
		&item.ID,
		&item.ProviderKey,
		&item.CanonicalModelKey,
		&item.ProviderModelName,
		&item.ModelType,
		&item.DisplayName,
		&capabilities,
		&billingConfig,
		&rateLimitPolicy,
		&item.RuntimePolicySetID,
		&runtimePolicyOverride,
	)
	if err != nil {
		// Returned unwrapped so callers' not-found checks keep working.
		return modelCatalogSnapshot{}, err
	}

	item.Capabilities = decodeObject(capabilities)
	item.BaseBillingConfig = decodeObject(billingConfig)
	item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
	item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
	return item, nil
}

// normalizePlatformModelAlias derives the alias stored for a platform model:
//   - It trims alias; if that leaves it empty, it falls back to base's
//     ProviderModelName, DisplayName or CanonicalModelKey (via firstNonEmpty).
//   - It strips a leading "<ProviderKey>:" when base has a provider key.
//   - If the result equals base.CanonicalModelKey, it drops that key's leading
//     "prefix:" segment.
//
// The final result is trimmed again.
func normalizePlatformModelAlias(alias string, base modelCatalogSnapshot) string {
	alias = strings.TrimSpace(alias)
	if alias == "" {
		alias = firstNonEmpty(base.ProviderModelName, base.DisplayName, base.CanonicalModelKey)
	}
	if base.ProviderKey != "" {
		alias = strings.TrimPrefix(alias, base.ProviderKey+":")
	}
	if alias == base.CanonicalModelKey {
		alias = stripAliasPrefix(alias)
	}
	return strings.TrimSpace(alias)
}

// stripAliasPrefix returns the part of alias after its first ':'. It does so
// only when both sides of that colon are non-empty; otherwise alias is
// returned unchanged.
func stripAliasPrefix(alias string) string {
	if before, after, ok := strings.Cut(alias, ":"); ok && before != "" && after != "" {
		return after
	}
	return alias
}

// mergeObjects returns a shallow merge of base and override, with override's
// keys winning. Neither input is modified. It returns nil when the merged
// result is empty.
func mergeObjects(base map[string]any, override map[string]any) map[string]any {
	out := make(map[string]any, len(base)+len(override))
	for key, value := range base {
		out[key] = value
	}
	for key, value := range override {
		out[key] = value
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

// emptyObjectIfNil substitutes an empty map for nil, so that json.Marshal
// produces "{}" rather than "null".
func emptyObjectIfNil(value map[string]any) map[string]any {
	if value == nil {
		return map[string]any{}
	}
	return value
}