344 lines
11 KiB
Go
344 lines
11 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
type platformModelQuerier interface {
|
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
|
}
|
|
|
|
// modelCatalogSnapshot is an in-memory copy of one base_model_catalog row as
// read by lookupBaseModel. Its zero value stands for "no catalog entry" and
// is what createPlatformModel works with when the lookup finds nothing.
type modelCatalogSnapshot struct {
	// Identity columns.
	ID                string // id::text
	ProviderKey       string // provider_key
	CanonicalModelKey string // canonical_model_key
	ProviderModelName string // provider_model_name
	ModelType         string // model_type
	DisplayName       string // display_name

	// JSON columns, decoded with decodeObject.
	Capabilities           map[string]any // capabilities
	BaseBillingConfig      map[string]any // base_billing_config
	DefaultRateLimitPolicy map[string]any // default_rate_limit_policy

	// Runtime policy reference and its per-model override.
	RuntimePolicySetID    string         // runtime_policy_set_id, "" when NULL
	RuntimePolicyOverride map[string]any // runtime_policy_override
}
|
|
|
|
func (s *Store) CreatePlatformModel(ctx context.Context, input CreatePlatformModelInput) (PlatformModel, error) {
|
|
return s.createPlatformModel(ctx, s.pool, input)
|
|
}
|
|
|
|
func (s *Store) ReplacePlatformModels(ctx context.Context, platformID string, inputs []CreatePlatformModelInput) ([]PlatformModel, error) {
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
keptIDs := make([]string, 0, len(inputs))
|
|
for _, input := range inputs {
|
|
input.PlatformID = platformID
|
|
model, err := s.createPlatformModel(ctx, tx, input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
keptIDs = append(keptIDs, model.ID)
|
|
}
|
|
|
|
if len(keptIDs) == 0 {
|
|
if _, err := tx.Exec(ctx, `
|
|
WITH deleted AS (
|
|
DELETE FROM platform_models
|
|
WHERE platform_id = $1::uuid
|
|
RETURNING id
|
|
)
|
|
DELETE FROM gateway_access_rules
|
|
WHERE resource_type = 'platform_model'
|
|
AND resource_id IN (SELECT id FROM deleted)`, platformID); err != nil {
|
|
return nil, err
|
|
}
|
|
} else if _, err := tx.Exec(ctx, `
|
|
WITH deleted AS (
|
|
DELETE FROM platform_models
|
|
WHERE platform_id = $1::uuid
|
|
AND NOT (id::text = ANY($2::text[]))
|
|
RETURNING id
|
|
)
|
|
DELETE FROM gateway_access_rules
|
|
WHERE resource_type = 'platform_model'
|
|
AND resource_id IN (SELECT id FROM deleted)`, platformID, keptIDs); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return s.ListPlatformModels(ctx, platformID)
|
|
}
|
|
|
|
func (s *Store) createPlatformModel(ctx context.Context, q platformModelQuerier, input CreatePlatformModelInput) (PlatformModel, error) {
|
|
base, err := s.lookupBaseModel(ctx, q, input.BaseModelID, input.CanonicalModelKey, input.ModelName)
|
|
if err != nil && !IsNotFound(err) {
|
|
return PlatformModel{}, err
|
|
}
|
|
if input.ModelType == "" {
|
|
input.ModelType = base.ModelType
|
|
}
|
|
if input.ModelName == "" {
|
|
input.ModelName = base.ProviderModelName
|
|
}
|
|
if input.DisplayName == "" {
|
|
input.DisplayName = firstNonEmpty(base.DisplayName, input.ModelName)
|
|
}
|
|
input.ModelAlias = normalizePlatformModelAlias(input.ModelAlias, base)
|
|
if input.PricingMode == "" {
|
|
input.PricingMode = "inherit_discount"
|
|
}
|
|
capabilities := input.Capabilities
|
|
if len(capabilities) == 0 {
|
|
capabilities = mergeObjects(base.Capabilities, input.CapabilityOverride)
|
|
}
|
|
billingConfig := input.BillingConfig
|
|
if len(billingConfig) == 0 {
|
|
billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride)
|
|
}
|
|
rateLimitPolicy := input.RateLimitPolicy
|
|
if len(rateLimitPolicy) == 0 {
|
|
rateLimitPolicy = base.DefaultRateLimitPolicy
|
|
}
|
|
runtimePolicySetID := strings.TrimSpace(input.RuntimePolicySetID)
|
|
if runtimePolicySetID == "" {
|
|
runtimePolicySetID = base.RuntimePolicySetID
|
|
}
|
|
runtimePolicyOverride := input.RuntimePolicyOverride
|
|
if len(runtimePolicyOverride) == 0 {
|
|
runtimePolicyOverride = base.RuntimePolicyOverride
|
|
}
|
|
|
|
capabilityOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.CapabilityOverride))
|
|
capabilitiesJSON, _ := json.Marshal(emptyObjectIfNil(capabilities))
|
|
billingOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingConfigOverride))
|
|
billingJSON, _ := json.Marshal(emptyObjectIfNil(billingConfig))
|
|
permissionJSON, _ := json.Marshal(emptyObjectIfNil(input.PermissionConfig))
|
|
retryJSON, _ := json.Marshal(emptyObjectIfNil(input.RetryPolicy))
|
|
rateLimitJSON, _ := json.Marshal(emptyObjectIfNil(rateLimitPolicy))
|
|
runtimePolicyOverrideJSON, _ := json.Marshal(emptyObjectIfNil(runtimePolicyOverride))
|
|
|
|
discount := any(nil)
|
|
if input.DiscountFactor > 0 {
|
|
discount = input.DiscountFactor
|
|
}
|
|
baseID := any(nil)
|
|
if base.ID != "" {
|
|
baseID = base.ID
|
|
}
|
|
|
|
var model PlatformModel
|
|
var capabilityOverrideBytes []byte
|
|
var capabilitiesBytes []byte
|
|
var billingOverrideBytes []byte
|
|
var billingBytes []byte
|
|
var permissionBytes []byte
|
|
var retryPolicyBytes []byte
|
|
var rateLimitPolicyBytes []byte
|
|
var runtimePolicyOverrideBytes []byte
|
|
err = q.QueryRow(ctx, `
|
|
INSERT INTO platform_models (
|
|
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
|
|
capability_override, capabilities, pricing_mode, discount_factor,
|
|
pricing_rule_set_id, billing_config_override, billing_config, permission_config, retry_policy, rate_limit_policy,
|
|
runtime_policy_set_id, runtime_policy_override, enabled
|
|
)
|
|
VALUES (
|
|
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6,
|
|
$7::jsonb, $8::jsonb, $9, $10::numeric,
|
|
NULLIF($11, '')::uuid, $12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb, $16::jsonb,
|
|
NULLIF($17, '')::uuid, $18::jsonb, true
|
|
)
|
|
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
|
|
SET base_model_id = EXCLUDED.base_model_id,
|
|
model_alias = EXCLUDED.model_alias,
|
|
display_name = EXCLUDED.display_name,
|
|
capability_override = EXCLUDED.capability_override,
|
|
capabilities = EXCLUDED.capabilities,
|
|
pricing_mode = EXCLUDED.pricing_mode,
|
|
discount_factor = EXCLUDED.discount_factor,
|
|
pricing_rule_set_id = EXCLUDED.pricing_rule_set_id,
|
|
billing_config_override = EXCLUDED.billing_config_override,
|
|
billing_config = EXCLUDED.billing_config,
|
|
permission_config = EXCLUDED.permission_config,
|
|
retry_policy = EXCLUDED.retry_policy,
|
|
rate_limit_policy = EXCLUDED.rate_limit_policy,
|
|
runtime_policy_set_id = EXCLUDED.runtime_policy_set_id,
|
|
runtime_policy_override = EXCLUDED.runtime_policy_override,
|
|
enabled = true,
|
|
updated_at = now()
|
|
RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_name,
|
|
COALESCE(model_alias, ''), model_type, display_name, capability_override,
|
|
capabilities, pricing_mode, COALESCE(discount_factor, 0)::float8,
|
|
COALESCE(pricing_rule_set_id::text, ''), billing_config_override, billing_config,
|
|
permission_config, retry_policy, rate_limit_policy, COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override,
|
|
enabled, created_at, updated_at`,
|
|
input.PlatformID,
|
|
baseID,
|
|
input.ModelName,
|
|
input.ModelAlias,
|
|
input.ModelType,
|
|
input.DisplayName,
|
|
string(capabilityOverrideJSON),
|
|
string(capabilitiesJSON),
|
|
input.PricingMode,
|
|
discount,
|
|
input.PricingRuleSetID,
|
|
string(billingOverrideJSON),
|
|
string(billingJSON),
|
|
string(permissionJSON),
|
|
string(retryJSON),
|
|
string(rateLimitJSON),
|
|
runtimePolicySetID,
|
|
string(runtimePolicyOverrideJSON),
|
|
).Scan(
|
|
&model.ID,
|
|
&model.PlatformID,
|
|
&model.BaseModelID,
|
|
&model.ModelName,
|
|
&model.ModelAlias,
|
|
&model.ModelType,
|
|
&model.DisplayName,
|
|
&capabilityOverrideBytes,
|
|
&capabilitiesBytes,
|
|
&model.PricingMode,
|
|
&model.DiscountFactor,
|
|
&model.PricingRuleSetID,
|
|
&billingOverrideBytes,
|
|
&billingBytes,
|
|
&permissionBytes,
|
|
&retryPolicyBytes,
|
|
&rateLimitPolicyBytes,
|
|
&model.RuntimePolicySetID,
|
|
&runtimePolicyOverrideBytes,
|
|
&model.Enabled,
|
|
&model.CreatedAt,
|
|
&model.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return PlatformModel{}, err
|
|
}
|
|
model.CapabilityOverride = decodeObject(capabilityOverrideBytes)
|
|
model.Capabilities = decodeObject(capabilitiesBytes)
|
|
model.BillingConfigOverride = decodeObject(billingOverrideBytes)
|
|
model.BillingConfig = decodeObject(billingBytes)
|
|
model.PermissionConfig = decodeObject(permissionBytes)
|
|
model.RetryPolicy = decodeObject(retryPolicyBytes)
|
|
model.RateLimitPolicy = decodeObject(rateLimitPolicyBytes)
|
|
model.RuntimePolicyOverride = decodeObject(runtimePolicyOverrideBytes)
|
|
return model, nil
|
|
}
|
|
|
|
func (s *Store) DeletePlatformModel(ctx context.Context, id string) error {
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
result, err := tx.Exec(ctx, `DELETE FROM platform_models WHERE id = $1::uuid`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.RowsAffected() == 0 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
if _, err := tx.Exec(ctx, `
|
|
DELETE FROM gateway_access_rules
|
|
WHERE resource_type = 'platform_model' AND resource_id = $1::uuid`, id); err != nil {
|
|
return err
|
|
}
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func (s *Store) lookupBaseModel(ctx context.Context, q platformModelQuerier, id string, canonicalKey string, modelName string) (modelCatalogSnapshot, error) {
|
|
var item modelCatalogSnapshot
|
|
var capabilities []byte
|
|
var billingConfig []byte
|
|
var rateLimitPolicy []byte
|
|
var runtimePolicyOverride []byte
|
|
err := q.QueryRow(ctx, `
|
|
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
|
|
capabilities, base_billing_config, default_rate_limit_policy,
|
|
COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override
|
|
FROM base_model_catalog
|
|
WHERE ($1 <> '' AND id = NULLIF($1, '')::uuid)
|
|
OR ($2 <> '' AND canonical_model_key = $2)
|
|
OR ($3 <> '' AND provider_model_name = $3)
|
|
ORDER BY CASE WHEN id::text = $1 THEN 0 WHEN canonical_model_key = $2 THEN 1 ELSE 2 END
|
|
LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSpace(modelName)).Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.CanonicalModelKey,
|
|
&item.ProviderModelName,
|
|
&item.ModelType,
|
|
&item.DisplayName,
|
|
&capabilities,
|
|
&billingConfig,
|
|
&rateLimitPolicy,
|
|
&item.RuntimePolicySetID,
|
|
&runtimePolicyOverride,
|
|
)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return modelCatalogSnapshot{}, err
|
|
}
|
|
return modelCatalogSnapshot{}, err
|
|
}
|
|
item.Capabilities = decodeObject(capabilities)
|
|
item.BaseBillingConfig = decodeObject(billingConfig)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
|
|
return item, nil
|
|
}
|
|
|
|
func normalizePlatformModelAlias(alias string, base modelCatalogSnapshot) string {
|
|
alias = strings.TrimSpace(alias)
|
|
if alias == "" {
|
|
alias = firstNonEmpty(base.ProviderModelName, base.DisplayName, base.CanonicalModelKey)
|
|
}
|
|
if base.ProviderKey != "" {
|
|
alias = strings.TrimPrefix(alias, base.ProviderKey+":")
|
|
}
|
|
if alias == base.CanonicalModelKey {
|
|
alias = stripAliasPrefix(alias)
|
|
}
|
|
return strings.TrimSpace(alias)
|
|
}
|
|
|
|
// stripAliasPrefix drops everything up to and including the first ':' of
// alias, e.g. "vendor:model" becomes "model". The input is returned unchanged
// when it has no ':', or when the text before or after the first ':' is empty.
func stripAliasPrefix(alias string) string {
	sep := strings.IndexByte(alias, ':')
	if sep <= 0 {
		// No separator at all, or nothing in front of it.
		return alias
	}
	if sep == len(alias)-1 {
		// Nothing after the separator.
		return alias
	}
	return alias[sep+1:]
}
|
|
|
|
// mergeObjects returns a new map holding every key of base and override; for
// a key present in both, the value from override wins. The merge is shallow
// (values are copied as-is) and neither argument is modified. When both
// inputs are empty or nil the result is nil rather than an empty map.
func mergeObjects(base map[string]any, override map[string]any) map[string]any {
	if len(base) == 0 && len(override) == 0 {
		return nil
	}
	merged := make(map[string]any, len(base)+len(override))
	// Later layers overwrite earlier ones, so override is applied last.
	for _, layer := range []map[string]any{base, override} {
		for k, v := range layer {
			merged[k] = v
		}
	}
	return merged
}
|
|
|
|
// emptyObjectIfNil substitutes a fresh empty map for a nil one and returns
// any non-nil map (including an empty one) unchanged. createPlatformModel
// applies it before json.Marshal, so a nil map is encoded as {} and not null.
func emptyObjectIfNil(value map[string]any) map[string]any {
	if value != nil {
		return value
	}
	return map[string]any{}
}
|