472 lines
16 KiB
Go
472 lines
16 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
// baseModelColumns is the column list used by every SELECT and RETURNING
// clause on base_model_catalog in this file. scanBaseModel scans these columns
// positionally, so the two must be kept in the same order.
//
// pricing_rule_set_id, runtime_policy_set_id and customized_at are cast to
// text with NULL mapped to '', and a NULL default_snapshot is returned as an
// empty JSON object.
const baseModelColumns = `
id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy, COALESCE(pricing_rule_set_id::text, ''),
COALESCE(runtime_policy_set_id::text, ''), runtime_policy_override, metadata,
catalog_type, COALESCE(default_snapshot, '{}'::jsonb), COALESCE(customized_at::text, ''),
pricing_version, status, created_at, updated_at`
|
|
|
|
// BaseModelInput is the payload accepted by CreateBaseModel and
// UpdateBaseModel. Both run it through normalizeBaseModelInput, which trims
// the string fields and fills in the defaults noted below.
type BaseModelInput struct {
	// ProviderKey is matched against model_catalog_providers.provider_key or
	// provider_code to resolve provider_id.
	ProviderKey string `json:"providerKey"`
	// CanonicalModelKey defaults to "<ProviderKey>:<ProviderModelName>" when
	// it is empty and both parts are set.
	CanonicalModelKey string `json:"canonicalModelKey"`
	ProviderModelName string `json:"providerModelName"`
	// ModelType accepts a JSON string or array; entries are trimmed and
	// de-duplicated, it defaults to ["text_generate"], and only the first
	// entry is written to the model_type column.
	ModelType StringList `json:"modelType"`
	// ModelAlias is written to the display_name column. When empty it falls
	// back to DisplayName, then to ProviderModelName.
	ModelAlias string `json:"modelAlias"`
	// DisplayName is only consulted here as a fallback for ModelAlias.
	DisplayName string `json:"displayName"`
	// JSON object columns: each map is passed through emptyObjectIfNil and
	// JSON-encoded before being stored.
	Capabilities           map[string]any `json:"capabilities"`
	BaseBillingConfig      map[string]any `json:"baseBillingConfig"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy"`
	// PricingRuleSetID is a UUID; empty selects the 'default-multimodal-v1'
	// pricing rule set.
	PricingRuleSetID string `json:"pricingRuleSetId"`
	// RuntimePolicySetID is a UUID; empty selects the 'default-runtime-v1'
	// runtime policy set.
	RuntimePolicySetID string `json:"runtimePolicySetId"`
	// RuntimePolicyOverride and Metadata are JSON object columns encoded the
	// same way as Capabilities.
	RuntimePolicyOverride map[string]any `json:"runtimePolicyOverride"`
	Metadata              map[string]any `json:"metadata"`
	// CatalogType defaults to "custom". Rows with catalog type "system" can be
	// restored from their default snapshot, and UpdateBaseModel stamps
	// customized_at on them.
	CatalogType string `json:"catalogType"`
	// DefaultSnapshot that encodes to {} is stored as NULL on create; on
	// update it keeps the existing snapshot.
	DefaultSnapshot map[string]any `json:"defaultSnapshot"`
	// PricingVersion values <= 0 are normalized to 1.
	PricingVersion int `json:"pricingVersion"`
	// Status defaults to "active".
	Status string `json:"status"`
}
|
|
|
|
// baseModelScanner is the single Scan method scanBaseModel needs. It is
// satisfied both by the row returned from QueryRow and by pgx.Rows, so one
// scan routine serves single-row and multi-row queries.
type baseModelScanner interface {
	Scan(dest ...any) error
}
|
|
|
|
// StringList is a list of strings whose JSON form may be either an array of
// strings or a single string (see UnmarshalJSON).
type StringList []string
|
|
|
|
func (list *StringList) UnmarshalJSON(data []byte) error {
|
|
var values []string
|
|
if err := json.Unmarshal(data, &values); err == nil {
|
|
*list = uniqueStringList(values)
|
|
return nil
|
|
}
|
|
var value string
|
|
if err := json.Unmarshal(data, &value); err != nil {
|
|
return err
|
|
}
|
|
*list = uniqueStringList([]string{value})
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT `+baseModelColumns+`
|
|
FROM base_model_catalog
|
|
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]BaseModel, 0)
|
|
for rows.Next() {
|
|
item, err := scanBaseModel(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (BaseModel, error) {
|
|
input = normalizeBaseModelInput(input)
|
|
capabilities, _ := json.Marshal(emptyObjectIfNil(input.Capabilities))
|
|
billingConfig, _ := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig))
|
|
rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
|
|
runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride))
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot))
|
|
modelType := primaryString(input.ModelType, "text_generate")
|
|
|
|
return scanBaseModel(s.pool.QueryRow(ctx, `
|
|
INSERT INTO base_model_catalog (
|
|
provider_id, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
|
|
capabilities, base_billing_config, default_rate_limit_policy, pricing_rule_set_id, runtime_policy_set_id, runtime_policy_override,
|
|
metadata, catalog_type, default_snapshot, pricing_version, status
|
|
)
|
|
VALUES (
|
|
(SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1),
|
|
$1, $2, $3, $4, $5, $6, $7, $8,
|
|
COALESCE(NULLIF($9, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)),
|
|
COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)),
|
|
$11, $12, NULLIF($13, ''), NULLIF($14::jsonb, '{}'::jsonb), $15, $16
|
|
)
|
|
RETURNING `+baseModelColumns,
|
|
input.ProviderKey,
|
|
input.CanonicalModelKey,
|
|
input.ProviderModelName,
|
|
modelType,
|
|
input.ModelAlias,
|
|
capabilities,
|
|
billingConfig,
|
|
rateLimitPolicy,
|
|
input.PricingRuleSetID,
|
|
input.RuntimePolicySetID,
|
|
runtimePolicyOverride,
|
|
metadata,
|
|
input.CatalogType,
|
|
string(defaultSnapshot),
|
|
input.PricingVersion,
|
|
input.Status,
|
|
))
|
|
}
|
|
|
|
func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelInput) (BaseModel, error) {
|
|
input = normalizeBaseModelInput(input)
|
|
capabilities, _ := json.Marshal(emptyObjectIfNil(input.Capabilities))
|
|
billingConfig, _ := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig))
|
|
rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
|
|
runtimePolicyOverride, _ := json.Marshal(emptyObjectIfNil(input.RuntimePolicyOverride))
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
defaultSnapshot, _ := json.Marshal(emptyObjectIfNil(input.DefaultSnapshot))
|
|
modelType := primaryString(input.ModelType, "text_generate")
|
|
|
|
return scanBaseModel(s.pool.QueryRow(ctx, `
|
|
UPDATE base_model_catalog
|
|
SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $2 OR provider_code = $2 LIMIT 1),
|
|
provider_key = $2,
|
|
canonical_model_key = $3,
|
|
provider_model_name = $4,
|
|
model_type = $5,
|
|
display_name = $6,
|
|
capabilities = $7,
|
|
base_billing_config = $8,
|
|
default_rate_limit_policy = $9,
|
|
pricing_rule_set_id = COALESCE(NULLIF($10, '')::uuid, (SELECT id FROM model_pricing_rule_sets WHERE rule_set_key = 'default-multimodal-v1' LIMIT 1)),
|
|
runtime_policy_set_id = COALESCE(NULLIF($11, '')::uuid, (SELECT id FROM model_runtime_policy_sets WHERE policy_key = 'default-runtime-v1' LIMIT 1)),
|
|
runtime_policy_override = $12,
|
|
metadata = $13,
|
|
catalog_type = NULLIF($14, ''),
|
|
default_snapshot = COALESCE(NULLIF($15::jsonb, '{}'::jsonb), default_snapshot),
|
|
customized_at = CASE WHEN NULLIF($14, '') = 'system' THEN now() ELSE NULL END,
|
|
pricing_version = $16,
|
|
status = $17,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid
|
|
RETURNING `+baseModelColumns,
|
|
id,
|
|
input.ProviderKey,
|
|
input.CanonicalModelKey,
|
|
input.ProviderModelName,
|
|
modelType,
|
|
input.ModelAlias,
|
|
capabilities,
|
|
billingConfig,
|
|
rateLimitPolicy,
|
|
input.PricingRuleSetID,
|
|
input.RuntimePolicySetID,
|
|
runtimePolicyOverride,
|
|
metadata,
|
|
input.CatalogType,
|
|
string(defaultSnapshot),
|
|
input.PricingVersion,
|
|
input.Status,
|
|
))
|
|
}
|
|
|
|
func (s *Store) ResetBaseModelToDefault(ctx context.Context, id string) (BaseModel, error) {
|
|
var catalogType string
|
|
var snapshotBytes []byte
|
|
if err := s.pool.QueryRow(ctx, `
|
|
SELECT catalog_type, COALESCE(default_snapshot, '{}'::jsonb)
|
|
FROM base_model_catalog
|
|
WHERE id = $1::uuid`, id).Scan(&catalogType, &snapshotBytes); err != nil {
|
|
return BaseModel{}, err
|
|
}
|
|
if catalogType != "system" {
|
|
return BaseModel{}, ErrProtectedDefault
|
|
}
|
|
snapshot := decodeObject(snapshotBytes)
|
|
if len(snapshot) == 0 {
|
|
return BaseModel{}, ErrProtectedDefault
|
|
}
|
|
|
|
return scanBaseModel(s.pool.QueryRow(ctx, `
|
|
UPDATE base_model_catalog
|
|
SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = COALESCE($2::text, provider_key) OR provider_code = COALESCE($2::text, provider_key) LIMIT 1),
|
|
provider_key = COALESCE($2::text, provider_key),
|
|
canonical_model_key = COALESCE($3::text, canonical_model_key),
|
|
provider_model_name = COALESCE($4::text, provider_model_name),
|
|
model_type = COALESCE($5::text, model_type),
|
|
display_name = COALESCE($6::text, display_name),
|
|
capabilities = COALESCE($7::jsonb, capabilities),
|
|
base_billing_config = COALESCE($8::jsonb, base_billing_config),
|
|
default_rate_limit_policy = COALESCE($9::jsonb, default_rate_limit_policy),
|
|
pricing_rule_set_id = COALESCE(NULLIF($10::text, '')::uuid, pricing_rule_set_id),
|
|
runtime_policy_set_id = COALESCE(NULLIF($11::text, '')::uuid, runtime_policy_set_id),
|
|
runtime_policy_override = COALESCE($12::jsonb, runtime_policy_override),
|
|
metadata = COALESCE($13::jsonb, metadata),
|
|
pricing_version = COALESCE($14::integer, pricing_version),
|
|
status = COALESCE($15::text, status),
|
|
customized_at = NULL,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid
|
|
RETURNING `+baseModelColumns,
|
|
id,
|
|
stringFromSnapshot(snapshot, "providerKey"),
|
|
stringFromSnapshot(snapshot, "canonicalModelKey"),
|
|
stringFromSnapshot(snapshot, "providerModelName"),
|
|
stringFromSnapshot(snapshot, "modelType"),
|
|
stringFromSnapshot(snapshot, "modelAlias", "displayName"),
|
|
jsonFromSnapshot(snapshot, "capabilities"),
|
|
jsonFromSnapshot(snapshot, "baseBillingConfig"),
|
|
jsonFromSnapshot(snapshot, "defaultRateLimitPolicy"),
|
|
stringFromSnapshot(snapshot, "pricingRuleSetId"),
|
|
stringFromSnapshot(snapshot, "runtimePolicySetId"),
|
|
jsonFromSnapshot(snapshot, "runtimePolicyOverride"),
|
|
jsonFromSnapshot(snapshot, "metadata"),
|
|
intFromSnapshot(snapshot, "pricingVersion"),
|
|
stringFromSnapshot(snapshot, "status"),
|
|
))
|
|
}
|
|
|
|
func (s *Store) ResetAllBaseModelsToDefault(ctx context.Context) ([]BaseModel, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
UPDATE base_model_catalog
|
|
SET provider_id = (
|
|
SELECT id
|
|
FROM model_catalog_providers
|
|
WHERE provider_key = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key)
|
|
OR provider_code = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key)
|
|
LIMIT 1
|
|
),
|
|
provider_key = COALESCE(NULLIF(default_snapshot->>'providerKey', ''), provider_key),
|
|
canonical_model_key = COALESCE(NULLIF(default_snapshot->>'canonicalModelKey', ''), canonical_model_key),
|
|
provider_model_name = COALESCE(NULLIF(default_snapshot->>'providerModelName', ''), provider_model_name),
|
|
model_type = COALESCE(NULLIF(CASE
|
|
WHEN jsonb_typeof(default_snapshot->'modelType') = 'array' THEN default_snapshot->'modelType'->>0
|
|
ELSE default_snapshot->>'modelType'
|
|
END, ''), model_type),
|
|
display_name = COALESCE(NULLIF(COALESCE(default_snapshot->>'modelAlias', default_snapshot->>'displayName'), ''), display_name),
|
|
capabilities = COALESCE(default_snapshot->'capabilities', capabilities),
|
|
base_billing_config = COALESCE(default_snapshot->'baseBillingConfig', base_billing_config),
|
|
default_rate_limit_policy = COALESCE(default_snapshot->'defaultRateLimitPolicy', default_rate_limit_policy),
|
|
pricing_rule_set_id = COALESCE(NULLIF(default_snapshot->>'pricingRuleSetId', '')::uuid, pricing_rule_set_id),
|
|
runtime_policy_set_id = COALESCE(NULLIF(default_snapshot->>'runtimePolicySetId', '')::uuid, runtime_policy_set_id),
|
|
runtime_policy_override = COALESCE(default_snapshot->'runtimePolicyOverride', runtime_policy_override),
|
|
metadata = COALESCE(default_snapshot->'metadata', metadata),
|
|
pricing_version = COALESCE(NULLIF(default_snapshot->>'pricingVersion', '')::integer, pricing_version),
|
|
status = COALESCE(NULLIF(default_snapshot->>'status', ''), status),
|
|
customized_at = NULL,
|
|
updated_at = now()
|
|
WHERE catalog_type = 'system'
|
|
AND COALESCE(default_snapshot, '{}'::jsonb) <> '{}'::jsonb
|
|
RETURNING `+baseModelColumns)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanBaseModelRows(rows)
|
|
}
|
|
|
|
func (s *Store) DeleteBaseModel(ctx context.Context, id string) error {
|
|
result, err := s.pool.Exec(ctx, `DELETE FROM base_model_catalog WHERE id = $1::uuid`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.RowsAffected() == 0 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanBaseModelRows(rows pgx.Rows) ([]BaseModel, error) {
|
|
items := make([]BaseModel, 0)
|
|
for rows.Next() {
|
|
item, err := scanBaseModel(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
|
|
var item BaseModel
|
|
var modelType string
|
|
var modelAlias string
|
|
var capabilities []byte
|
|
var billingConfig []byte
|
|
var rateLimitPolicy []byte
|
|
var runtimePolicyOverride []byte
|
|
var metadata []byte
|
|
var defaultSnapshot []byte
|
|
if err := scanner.Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.CanonicalModelKey,
|
|
&item.ProviderModelName,
|
|
&modelType,
|
|
&modelAlias,
|
|
&capabilities,
|
|
&billingConfig,
|
|
&rateLimitPolicy,
|
|
&item.PricingRuleSetID,
|
|
&item.RuntimePolicySetID,
|
|
&runtimePolicyOverride,
|
|
&metadata,
|
|
&item.CatalogType,
|
|
&defaultSnapshot,
|
|
&item.CustomizedAt,
|
|
&item.PricingVersion,
|
|
&item.Status,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return BaseModel{}, err
|
|
}
|
|
item.Capabilities = decodeObject(capabilities)
|
|
item.BaseBillingConfig = decodeObject(billingConfig)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
|
|
item.Metadata = decodeObject(metadata)
|
|
item.DefaultSnapshot = decodeObject(defaultSnapshot)
|
|
item.ModelType = baseModelTypes(item.Capabilities, item.Metadata, modelType)
|
|
item.ModelAlias = modelAlias
|
|
item.DisplayName = modelAlias
|
|
return item, nil
|
|
}
|
|
|
|
func normalizeBaseModelInput(input BaseModelInput) BaseModelInput {
|
|
input.ProviderKey = strings.TrimSpace(input.ProviderKey)
|
|
input.CanonicalModelKey = strings.TrimSpace(input.CanonicalModelKey)
|
|
input.ProviderModelName = strings.TrimSpace(input.ProviderModelName)
|
|
input.ModelType = uniqueStringList(input.ModelType)
|
|
input.ModelAlias = strings.TrimSpace(input.ModelAlias)
|
|
input.DisplayName = strings.TrimSpace(input.DisplayName)
|
|
input.PricingRuleSetID = strings.TrimSpace(input.PricingRuleSetID)
|
|
input.RuntimePolicySetID = strings.TrimSpace(input.RuntimePolicySetID)
|
|
input.CatalogType = strings.TrimSpace(input.CatalogType)
|
|
input.Status = strings.TrimSpace(input.Status)
|
|
if input.CanonicalModelKey == "" && input.ProviderKey != "" && input.ProviderModelName != "" {
|
|
input.CanonicalModelKey = input.ProviderKey + ":" + input.ProviderModelName
|
|
}
|
|
if input.ModelAlias == "" {
|
|
input.ModelAlias = input.DisplayName
|
|
}
|
|
if input.ModelAlias == "" {
|
|
input.ModelAlias = input.ProviderModelName
|
|
}
|
|
if len(input.ModelType) == 0 {
|
|
input.ModelType = StringList{"text_generate"}
|
|
}
|
|
if input.CatalogType == "" {
|
|
input.CatalogType = "custom"
|
|
}
|
|
if input.PricingVersion <= 0 {
|
|
input.PricingVersion = 1
|
|
}
|
|
if input.Status == "" {
|
|
input.Status = "active"
|
|
}
|
|
return input
|
|
}
|
|
|
|
// stringFromSnapshot returns the first usable text value found under keys,
// tried in order, or nil if none is found. A string value is used when it is
// not blank. A []any or []string value contributes its first non-blank string
// element. The result is always whitespace-trimmed, so every branch agrees.
// The nil result is passed to SQL as NULL so that COALESCE keeps the current
// column.
func stringFromSnapshot(snapshot map[string]any, keys ...string) any {
	for _, key := range keys {
		value, ok := snapshot[key]
		if !ok {
			continue
		}
		switch typed := value.(type) {
		case string:
			if text := strings.TrimSpace(typed); text != "" {
				return text
			}
		case []any:
			for _, item := range typed {
				if text, isText := item.(string); isText {
					if text = strings.TrimSpace(text); text != "" {
						return text
					}
				}
			}
		case []string:
			for _, item := range typed {
				if text := strings.TrimSpace(item); text != "" {
					return text
				}
			}
		}
	}
	return nil
}
|
|
|
|
// intFromSnapshot returns snapshot[key] as an int when it is a float64 (the
// type encoding/json produces for numbers, truncated toward zero) or already
// an int. Anything else, including a missing key, yields nil.
func intFromSnapshot(snapshot map[string]any, key string) any {
	raw := snapshot[key]
	if number, ok := raw.(float64); ok {
		return int(number)
	}
	if number, ok := raw.(int); ok {
		return number
	}
	return nil
}
|
|
|
|
// jsonFromSnapshot re-encodes snapshot[key] as a JSON string for use as a
// jsonb parameter. It returns nil when the key is missing, its value is nil,
// or the value cannot be marshaled.
func jsonFromSnapshot(snapshot map[string]any, key string) any {
	if value := snapshot[key]; value != nil {
		if encoded, err := json.Marshal(value); err == nil {
			return string(encoded)
		}
	}
	return nil
}
|
|
|
|
func baseModelTypes(capabilities map[string]any, metadata map[string]any, fallback string) StringList {
|
|
values := make([]string, 0)
|
|
values = append(values, stringListFromAny(capabilities["originalTypes"])...)
|
|
values = append(values, stringListFromAny(metadata["originalTypes"])...)
|
|
if fallback != "" {
|
|
values = append(values, fallback)
|
|
}
|
|
return uniqueStringList(values)
|
|
}
|
|
|
|
// stringListFromAny extracts strings from a []string (returned as-is, not
// copied) or a []any (keeping only its string elements, untrimmed). Any other
// value yields nil.
func stringListFromAny(value any) []string {
	if list, ok := value.([]string); ok {
		return list
	}
	items, ok := value.([]any)
	if !ok {
		return nil
	}
	texts := make([]string, 0, len(items))
	for _, item := range items {
		if text, isText := item.(string); isText {
			texts = append(texts, text)
		}
	}
	return texts
}
|
|
|
|
func uniqueStringList(values []string) StringList {
|
|
out := make([]string, 0, len(values))
|
|
seen := map[string]bool{}
|
|
for _, value := range values {
|
|
value = strings.TrimSpace(value)
|
|
if value == "" || seen[value] {
|
|
continue
|
|
}
|
|
seen[value] = true
|
|
out = append(out, value)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// primaryString returns the first element of values that is non-blank after
// trimming (in its trimmed form), or fallback when there is none.
func primaryString(values []string, fallback string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return fallback
}
|