easyai-ai-gateway/apps/api/internal/store/base_models.go

191 lines
5.7 KiB
Go

package store
import (
"context"
"encoding/json"
"strings"
"github.com/jackc/pgx/v5"
)
// baseModelColumns is the column list shared by the SELECT in ListBaseModels
// and the RETURNING clauses in CreateBaseModel and UpdateBaseModel.
//
// The column order matches, position for position, the Scan destinations in
// scanBaseModel; the two must be changed together. The id column is cast to
// text in SQL before it is scanned.
const baseModelColumns = `
id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy, metadata, pricing_version,
status, created_at, updated_at`
// BaseModelInput carries the caller-supplied values for one base_model_catalog
// row. It is the payload accepted by CreateBaseModel and UpdateBaseModel, both
// of which run it through normalizeBaseModelInput before touching the
// database.
type BaseModelInput struct {
	// Identity of the model. ProviderKey is also used to look up the
	// provider row; CanonicalModelKey is derived as
	// "<ProviderKey>:<ProviderModelName>" during normalization when it is
	// left empty and both parts are present.
	ProviderKey       string `json:"providerKey"`
	CanonicalModelKey string `json:"canonicalModelKey"`
	ProviderModelName string `json:"providerModelName"`

	// Descriptive fields. Normalization defaults ModelType to
	// "text_generate" and DisplayName to ProviderModelName when empty.
	ModelType   string `json:"modelType"`
	DisplayName string `json:"displayName"`

	// Free-form objects, JSON-encoded before being written. The store
	// passes each through emptyObjectIfNil first.
	Capabilities           map[string]any `json:"capabilities"`
	BaseBillingConfig      map[string]any `json:"baseBillingConfig"`
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy"`
	Metadata               map[string]any `json:"metadata"`

	// Normalization replaces a non-positive PricingVersion with 1 and an
	// empty Status with "active".
	PricingVersion int    `json:"pricingVersion"`
	Status         string `json:"status"`
}
// baseModelScanner is the minimal row-reading capability scanBaseModel needs:
// something that can copy the current row's columns into the given
// destinations. In this file it is satisfied both by the value returned from
// QueryRow and by the rows iterator returned from Query.
type baseModelScanner interface {
	// Scan writes the row's column values into dest, in column order.
	Scan(dest ...any) error
}
// ListBaseModels returns every row of base_model_catalog, ordered by provider
// key, then model type, then canonical model key (all ascending).
//
// On success the returned slice is non-nil, and empty when the table has no
// rows. On any failure (running the query, scanning a row, or iterating the
// result set) the slice is nil and the error is returned as-is, so callers
// never receive a partially read catalog.
func (s *Store) ListBaseModels(ctx context.Context) ([]BaseModel, error) {
	rows, err := s.pool.Query(ctx, `
SELECT `+baseModelColumns+`
FROM base_model_catalog
ORDER BY provider_key ASC, model_type ASC, canonical_model_key ASC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Start from an empty, non-nil slice; presumably so that an empty
	// catalog serializes as [] rather than null (confirm against callers).
	items := make([]BaseModel, 0)
	for rows.Next() {
		item, err := scanBaseModel(rows)
		if err != nil {
			return nil, err
		}
		items = append(items, item)
	}

	// The loop also ends when iteration fails part-way through. In that
	// case drop whatever was read so far, matching the scan-error branch
	// above, instead of returning a truncated list alongside the error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// CreateBaseModel inserts one row into base_model_catalog and returns the row
// as stored.
//
// input is first normalized by normalizeBaseModelInput. The four object-valued
// fields are JSON-encoded after being passed through emptyObjectIfNil.
// provider_id is resolved inside the statement by selecting the id of the
// first model_catalog_providers row whose provider_key or provider_code equals
// the normalized ProviderKey; when nothing matches, that scalar subquery
// evaluates to NULL (whether the insert then succeeds depends on column
// constraints not visible here).
//
// An error is returned, with a zero BaseModel, if any object field cannot be
// JSON-encoded or if the insert or the scan of the returned row fails. An
// encoding failure is reported before the database is contacted.
func (s *Store) CreateBaseModel(ctx context.Context, input BaseModelInput) (BaseModel, error) {
	input = normalizeBaseModelInput(input)

	// json.Marshal can fail for maps built in code (for example ones
	// holding NaN or a channel). Surface that instead of sending nil bytes
	// to the database.
	capabilities, err := json.Marshal(emptyObjectIfNil(input.Capabilities))
	if err != nil {
		return BaseModel{}, err
	}
	billingConfig, err := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig))
	if err != nil {
		return BaseModel{}, err
	}
	rateLimitPolicy, err := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
	if err != nil {
		return BaseModel{}, err
	}
	metadata, err := json.Marshal(emptyObjectIfNil(input.Metadata))
	if err != nil {
		return BaseModel{}, err
	}

	// $1 (the provider key) is used twice: once to look up provider_id and
	// once as the provider_key column value.
	return scanBaseModel(s.pool.QueryRow(ctx, `
INSERT INTO base_model_catalog (
provider_id, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
capabilities, base_billing_config, default_rate_limit_policy, metadata, pricing_version, status
)
VALUES (
(SELECT id FROM model_catalog_providers WHERE provider_key = $1 OR provider_code = $1 LIMIT 1),
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
)
RETURNING `+baseModelColumns,
		input.ProviderKey,
		input.CanonicalModelKey,
		input.ProviderModelName,
		input.ModelType,
		input.DisplayName,
		capabilities,
		billingConfig,
		rateLimitPolicy,
		metadata,
		input.PricingVersion,
		input.Status,
	))
}
// UpdateBaseModel overwrites the base_model_catalog row identified by id with
// the values in input and returns the row as stored.
//
// input is first normalized by normalizeBaseModelInput. Every column set by
// CreateBaseModel is replaced, provider_id is re-resolved from
// model_catalog_providers using the normalized ProviderKey (matched against
// provider_key or provider_code), and updated_at is set to now(). id is cast
// to uuid inside the statement.
//
// An error is returned, with a zero BaseModel, if any object field cannot be
// JSON-encoded or if the update or the scan of the returned row fails. An
// encoding failure is reported before the database is contacted. When no row
// has the given id the RETURNING clause yields nothing and the scan error is
// returned unchanged (presumably pgx.ErrNoRows; confirm against the pool in
// use).
func (s *Store) UpdateBaseModel(ctx context.Context, id string, input BaseModelInput) (BaseModel, error) {
	input = normalizeBaseModelInput(input)

	// json.Marshal can fail for maps built in code (for example ones
	// holding NaN or a channel). Surface that instead of sending nil bytes
	// to the database.
	capabilities, err := json.Marshal(emptyObjectIfNil(input.Capabilities))
	if err != nil {
		return BaseModel{}, err
	}
	billingConfig, err := json.Marshal(emptyObjectIfNil(input.BaseBillingConfig))
	if err != nil {
		return BaseModel{}, err
	}
	rateLimitPolicy, err := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
	if err != nil {
		return BaseModel{}, err
	}
	metadata, err := json.Marshal(emptyObjectIfNil(input.Metadata))
	if err != nil {
		return BaseModel{}, err
	}

	// $1 is the row id; $2 (the provider key) is used both for the
	// provider_id lookup and as the provider_key column value.
	return scanBaseModel(s.pool.QueryRow(ctx, `
UPDATE base_model_catalog
SET provider_id = (SELECT id FROM model_catalog_providers WHERE provider_key = $2 OR provider_code = $2 LIMIT 1),
provider_key = $2,
canonical_model_key = $3,
provider_model_name = $4,
model_type = $5,
display_name = $6,
capabilities = $7,
base_billing_config = $8,
default_rate_limit_policy = $9,
metadata = $10,
pricing_version = $11,
status = $12,
updated_at = now()
WHERE id = $1::uuid
RETURNING `+baseModelColumns,
		id,
		input.ProviderKey,
		input.CanonicalModelKey,
		input.ProviderModelName,
		input.ModelType,
		input.DisplayName,
		capabilities,
		billingConfig,
		rateLimitPolicy,
		metadata,
		input.PricingVersion,
		input.Status,
	))
}
// DeleteBaseModel removes the base_model_catalog row whose id equals the given
// value (cast to uuid inside the statement).
//
// It returns the execution error unchanged if the statement fails,
// pgx.ErrNoRows if the statement ran but removed nothing, and nil otherwise.
func (s *Store) DeleteBaseModel(ctx context.Context, id string) error {
	tag, err := s.pool.Exec(ctx, `DELETE FROM base_model_catalog WHERE id = $1::uuid`, id)
	switch {
	case err != nil:
		return err
	case tag.RowsAffected() == 0:
		// Nothing matched the id: report it the same way a missing row
		// is reported by a single-row query.
		return pgx.ErrNoRows
	default:
		return nil
	}
}
// scanBaseModel reads one row laid out as baseModelColumns from scanner and
// converts it into a BaseModel.
//
// Scalar columns are scanned straight into the result. The four object columns
// (capabilities, billing config, rate-limit policy, metadata) are scanned as
// raw bytes and then converted with decodeObject. If Scan fails, a zero
// BaseModel and the Scan error are returned.
func scanBaseModel(scanner baseModelScanner) (BaseModel, error) {
	var (
		model BaseModel

		// Raw holders for the object columns, decoded after the scan.
		rawCapabilities    []byte
		rawBillingConfig   []byte
		rawRateLimitPolicy []byte
		rawMetadata        []byte
	)

	// Destination order must mirror baseModelColumns.
	err := scanner.Scan(
		&model.ID,
		&model.ProviderKey,
		&model.CanonicalModelKey,
		&model.ProviderModelName,
		&model.ModelType,
		&model.DisplayName,
		&rawCapabilities,
		&rawBillingConfig,
		&rawRateLimitPolicy,
		&rawMetadata,
		&model.PricingVersion,
		&model.Status,
		&model.CreatedAt,
		&model.UpdatedAt,
	)
	if err != nil {
		return BaseModel{}, err
	}

	model.Capabilities = decodeObject(rawCapabilities)
	model.BaseBillingConfig = decodeObject(rawBillingConfig)
	model.DefaultRateLimitPolicy = decodeObject(rawRateLimitPolicy)
	model.Metadata = decodeObject(rawMetadata)
	return model, nil
}
// normalizeBaseModelInput returns a cleaned-up copy of input; the caller's
// value is not modified because input is passed by value.
//
// Steps, in order:
//  1. Trim surrounding whitespace from every string field.
//  2. If CanonicalModelKey is empty and both ProviderKey and
//     ProviderModelName are set, derive it as
//     "<ProviderKey>:<ProviderModelName>".
//  3. Default DisplayName to ProviderModelName (which may itself be empty).
//  4. Default ModelType to "text_generate".
//  5. Replace a zero or negative PricingVersion with 1.
//  6. Default Status to "active".
//
// The map-valued fields are left untouched.
func normalizeBaseModelInput(input BaseModelInput) BaseModelInput {
	// The pointers refer to the local copy, so trimming happens in place
	// on the value that is returned.
	for _, field := range []*string{
		&input.ProviderKey,
		&input.CanonicalModelKey,
		&input.ProviderModelName,
		&input.ModelType,
		&input.DisplayName,
		&input.Status,
	} {
		*field = strings.TrimSpace(*field)
	}

	canDeriveKey := input.ProviderKey != "" && input.ProviderModelName != ""
	if canDeriveKey && input.CanonicalModelKey == "" {
		input.CanonicalModelKey = input.ProviderKey + ":" + input.ProviderModelName
	}

	if input.DisplayName == "" {
		input.DisplayName = input.ProviderModelName
	}
	if input.ModelType == "" {
		input.ModelType = "text_generate"
	}
	if input.PricingVersion < 1 {
		input.PricingVersion = 1
	}
	if input.Status == "" {
		input.Status = "active"
	}
	return input
}