191 lines
5.7 KiB
Go
191 lines
5.7 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
// baseModelColumns is the column list shared by the SELECT in ListBaseModels
// and the RETURNING clauses of CreateBaseModel and UpdateBaseModel.
// The column order must stay in sync with the Scan destinations in
// scanBaseModel; reordering one without the other mis-assigns fields.
// id is cast to text, presumably so it scans into a string ID field
// (TODO confirm BaseModel.ID's type).
const baseModelColumns = `
id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy, metadata, pricing_version,
status, created_at, updated_at`
|
|
|
|
// BaseModelInput is the caller-supplied payload for CreateBaseModel and
// UpdateBaseModel (JSON keys are camelCase). Both methods pass it through
// normalizeBaseModelInput first, which trims the string fields and fills in
// defaults for empty ones as noted per field below.
type BaseModelInput struct {
	// ProviderKey is stored as provider_key and is also matched against
	// model_catalog_providers.provider_key / provider_code to resolve provider_id.
	ProviderKey string `json:"providerKey"`
	// CanonicalModelKey defaults to "<ProviderKey>:<ProviderModelName>" when
	// empty and both of those are set.
	CanonicalModelKey string `json:"canonicalModelKey"`
	// ProviderModelName is stored as provider_model_name; it is also the
	// fallback for DisplayName.
	ProviderModelName string `json:"providerModelName"`
	// ModelType defaults to "text_generate" when empty.
	ModelType string `json:"modelType"`
	// DisplayName defaults to ProviderModelName when empty.
	DisplayName string `json:"displayName"`
	// The four map fields are JSON-encoded before being written; nil maps are
	// passed through emptyObjectIfNil first.
	Capabilities           map[string]any `json:"capabilities"`
	BaseBillingConfig      map[string]any `json:"baseBillingConfig"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy"`
	Metadata               map[string]any `json:"metadata"`
	// PricingVersion is forced to 1 when zero or negative.
	PricingVersion int `json:"pricingVersion"`
	// Status defaults to "active" when empty.
	Status string `json:"status"`
}
|
|
|
|
// baseModelScanner is the single method scanBaseModel needs from a query
// result. Both a QueryRow result and the current row of a Query result satisfy
// it, so one decoding routine serves single-row and multi-row reads.
type baseModelScanner interface {
	Scan(targets ...any) error
}
|
|
|
|
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT `+baseModelColumns+`
|
|
FROM base_model_catalog
|
|
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]BaseModel, 0)
|
|
for rows.Next() {
|
|
item, err := scanBaseModel(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (BaseModel, error) {
|
|
input = normalizeBaseModelInput(input)
|
|
capabilities, _ := json.Marshal(emptyObjectIfNil(input.Capabilities))
|
|
billingConfig, _ := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig))
|
|
rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
|
|
return scanBaseModel(s.pool.QueryRow(ctx, `
|
|
INSERT INTO base_model_catalog (
|
|
provider_id, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
|
|
capabilities, base_billing_config, default_rate_limit_policy, metadata, pricing_version, status
|
|
)
|
|
VALUES (
|
|
(SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1),
|
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
|
|
)
|
|
RETURNING `+baseModelColumns,
|
|
input.ProviderKey,
|
|
input.CanonicalModelKey,
|
|
input.ProviderModelName,
|
|
input.ModelType,
|
|
input.DisplayName,
|
|
capabilities,
|
|
billingConfig,
|
|
rateLimitPolicy,
|
|
metadata,
|
|
input.PricingVersion,
|
|
input.Status,
|
|
))
|
|
}
|
|
|
|
func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelInput) (BaseModel, error) {
|
|
input = normalizeBaseModelInput(input)
|
|
capabilities, _ := json.Marshal(emptyObjectIfNil(input.Capabilities))
|
|
billingConfig, _ := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig))
|
|
rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
|
|
return scanBaseModel(s.pool.QueryRow(ctx, `
|
|
UPDATE base_model_catalog
|
|
SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $2 OR provider_code = $2 LIMIT 1),
|
|
provider_key = $2,
|
|
canonical_model_key = $3,
|
|
provider_model_name = $4,
|
|
model_type = $5,
|
|
display_name = $6,
|
|
capabilities = $7,
|
|
base_billing_config = $8,
|
|
default_rate_limit_policy = $9,
|
|
metadata = $10,
|
|
pricing_version = $11,
|
|
status = $12,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid
|
|
RETURNING `+baseModelColumns,
|
|
id,
|
|
input.ProviderKey,
|
|
input.CanonicalModelKey,
|
|
input.ProviderModelName,
|
|
input.ModelType,
|
|
input.DisplayName,
|
|
capabilities,
|
|
billingConfig,
|
|
rateLimitPolicy,
|
|
metadata,
|
|
input.PricingVersion,
|
|
input.Status,
|
|
))
|
|
}
|
|
|
|
func (s *Store) DeleteBaseModel(ctx context.Context, id string) error {
|
|
result, err := s.pool.Exec(ctx, `DELETE FROM base_model_catalog WHERE id = $1::uuid`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.RowsAffected() == 0 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
|
|
var item BaseModel
|
|
var capabilities []byte
|
|
var billingConfig []byte
|
|
var rateLimitPolicy []byte
|
|
var metadata []byte
|
|
if err := scanner.Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.CanonicalModelKey,
|
|
&item.ProviderModelName,
|
|
&item.ModelType,
|
|
&item.DisplayName,
|
|
&capabilities,
|
|
&billingConfig,
|
|
&rateLimitPolicy,
|
|
&metadata,
|
|
&item.PricingVersion,
|
|
&item.Status,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return BaseModel{}, err
|
|
}
|
|
item.Capabilities = decodeObject(capabilities)
|
|
item.BaseBillingConfig = decodeObject(billingConfig)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
item.Metadata = decodeObject(metadata)
|
|
return item, nil
|
|
}
|
|
|
|
func normalizeBaseModelInput(input BaseModelInput) BaseModelInput {
|
|
input.ProviderKey = strings.TrimSpace(input.ProviderKey)
|
|
input.CanonicalModelKey = strings.TrimSpace(input.CanonicalModelKey)
|
|
input.ProviderModelName = strings.TrimSpace(input.ProviderModelName)
|
|
input.ModelType = strings.TrimSpace(input.ModelType)
|
|
input.DisplayName = strings.TrimSpace(input.DisplayName)
|
|
input.Status = strings.TrimSpace(input.Status)
|
|
if input.CanonicalModelKey == "" && input.ProviderKey != "" && input.ProviderModelName != "" {
|
|
input.CanonicalModelKey = input.ProviderKey + ":" + input.ProviderModelName
|
|
}
|
|
if input.DisplayName == "" {
|
|
input.DisplayName = input.ProviderModelName
|
|
}
|
|
if input.ModelType == "" {
|
|
input.ModelType = "text_generate"
|
|
}
|
|
if input.PricingVersion <= 0 {
|
|
input.PricingVersion = 1
|
|
}
|
|
if input.Status == "" {
|
|
input.Status = "active"
|
|
}
|
|
return input
|
|
}
|