207 lines
6.5 KiB
Go
207 lines
6.5 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
// modelCatalogSnapshot is an in-memory copy of a single base_model_catalog
// row, as loaded by lookupBaseModel. CreatePlatformModel uses it as the source
// of default values for a platform model.
type modelCatalogSnapshot struct {
	// ID is the catalog row id rendered as text; it is empty on the zero
	// value, i.e. when no catalog row was found.
	ID string
	// ProviderKey mirrors the provider_key column.
	ProviderKey string
	// CanonicalModelKey mirrors the canonical_model_key column.
	CanonicalModelKey string
	// ProviderModelName mirrors the provider_model_name column.
	ProviderModelName string
	// ModelType mirrors the model_type column.
	ModelType string
	// DisplayName mirrors the display_name column.
	DisplayName string
	// Capabilities is the decoded capabilities column.
	Capabilities map[string]any
	// BaseBillingConfig is the decoded base_billing_config column.
	BaseBillingConfig map[string]any
	// DefaultRateLimitPolicy is the decoded default_rate_limit_policy column.
	DefaultRateLimitPolicy map[string]any
}
|
|
|
|
func (s *Store) CreatePlatformModel(ctx context.Context, input CreatePlatformModelInput) (PlatformModel, error) {
|
|
base, err := s.lookupBaseModel(ctx, input.BaseModelID, input.CanonicalModelKey, input.ModelName)
|
|
if err != nil && !IsNotFound(err) {
|
|
return PlatformModel{}, err
|
|
}
|
|
if input.ModelType == "" {
|
|
input.ModelType = base.ModelType
|
|
}
|
|
if input.ModelName == "" {
|
|
input.ModelName = base.ProviderModelName
|
|
}
|
|
if input.DisplayName == "" {
|
|
input.DisplayName = firstNonEmpty(base.DisplayName, input.ModelName)
|
|
}
|
|
if input.PricingMode == "" {
|
|
input.PricingMode = "inherit_discount"
|
|
}
|
|
capabilities := input.Capabilities
|
|
if len(capabilities) == 0 {
|
|
capabilities = mergeObjects(base.Capabilities, input.CapabilityOverride)
|
|
}
|
|
billingConfig := input.BillingConfig
|
|
if len(billingConfig) == 0 {
|
|
billingConfig = mergeObjects(base.BaseBillingConfig, input.BillingConfigOverride)
|
|
}
|
|
rateLimitPolicy := input.RateLimitPolicy
|
|
if len(rateLimitPolicy) == 0 {
|
|
rateLimitPolicy = base.DefaultRateLimitPolicy
|
|
}
|
|
|
|
capabilityOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.CapabilityOverride))
|
|
capabilitiesJSON, _ := json.Marshal(emptyObjectIfNil(capabilities))
|
|
billingOverrideJSON, _ := json.Marshal(emptyObjectIfNil(input.BillingConfigOverride))
|
|
billingJSON, _ := json.Marshal(emptyObjectIfNil(billingConfig))
|
|
permissionJSON, _ := json.Marshal(emptyObjectIfNil(input.PermissionConfig))
|
|
retryJSON, _ := json.Marshal(emptyObjectIfNil(input.RetryPolicy))
|
|
rateLimitJSON, _ := json.Marshal(emptyObjectIfNil(rateLimitPolicy))
|
|
|
|
discount := any(nil)
|
|
if input.DiscountFactor > 0 {
|
|
discount = input.DiscountFactor
|
|
}
|
|
baseID := any(nil)
|
|
if base.ID != "" {
|
|
baseID = base.ID
|
|
}
|
|
|
|
var model PlatformModel
|
|
var capabilityOverrideBytes []byte
|
|
var capabilitiesBytes []byte
|
|
var billingOverrideBytes []byte
|
|
var billingBytes []byte
|
|
err = s.pool.QueryRow(ctx, `
|
|
INSERT INTO platform_models (
|
|
platform_id, base_model_id, model_name, model_alias, model_type, display_name,
|
|
capability_override, capabilities, pricing_mode, discount_factor,
|
|
pricing_rule_set_id, billing_config_override, billing_config, permission_config, retry_policy, rate_limit_policy, enabled
|
|
)
|
|
VALUES (
|
|
$1::uuid, $2::uuid, $3, NULLIF($4, ''), $5, $6,
|
|
$7::jsonb, $8::jsonb, $9, $10::numeric,
|
|
NULLIF($11, '')::uuid, $12::jsonb, $13::jsonb, $14::jsonb, $15::jsonb, $16::jsonb, true
|
|
)
|
|
ON CONFLICT (platform_id, model_name, model_type) DO UPDATE
|
|
SET base_model_id = EXCLUDED.base_model_id,
|
|
model_alias = EXCLUDED.model_alias,
|
|
display_name = EXCLUDED.display_name,
|
|
capability_override = EXCLUDED.capability_override,
|
|
capabilities = EXCLUDED.capabilities,
|
|
pricing_mode = EXCLUDED.pricing_mode,
|
|
discount_factor = EXCLUDED.discount_factor,
|
|
pricing_rule_set_id = EXCLUDED.pricing_rule_set_id,
|
|
billing_config_override = EXCLUDED.billing_config_override,
|
|
billing_config = EXCLUDED.billing_config,
|
|
permission_config = EXCLUDED.permission_config,
|
|
retry_policy = EXCLUDED.retry_policy,
|
|
rate_limit_policy = EXCLUDED.rate_limit_policy,
|
|
enabled = true,
|
|
updated_at = now()
|
|
RETURNING id::text, platform_id::text, COALESCE(base_model_id::text, ''), model_name,
|
|
COALESCE(model_alias, ''), model_type, display_name, capability_override,
|
|
capabilities, pricing_mode, COALESCE(discount_factor, 0)::float8,
|
|
COALESCE(pricing_rule_set_id::text, ''), billing_config_override, billing_config, enabled, created_at, updated_at`,
|
|
input.PlatformID,
|
|
baseID,
|
|
input.ModelName,
|
|
input.ModelAlias,
|
|
input.ModelType,
|
|
input.DisplayName,
|
|
string(capabilityOverrideJSON),
|
|
string(capabilitiesJSON),
|
|
input.PricingMode,
|
|
discount,
|
|
input.PricingRuleSetID,
|
|
string(billingOverrideJSON),
|
|
string(billingJSON),
|
|
string(permissionJSON),
|
|
string(retryJSON),
|
|
string(rateLimitJSON),
|
|
).Scan(
|
|
&model.ID,
|
|
&model.PlatformID,
|
|
&model.BaseModelID,
|
|
&model.ModelName,
|
|
&model.ModelAlias,
|
|
&model.ModelType,
|
|
&model.DisplayName,
|
|
&capabilityOverrideBytes,
|
|
&capabilitiesBytes,
|
|
&model.PricingMode,
|
|
&model.DiscountFactor,
|
|
&model.PricingRuleSetID,
|
|
&billingOverrideBytes,
|
|
&billingBytes,
|
|
&model.Enabled,
|
|
&model.CreatedAt,
|
|
&model.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return PlatformModel{}, err
|
|
}
|
|
model.CapabilityOverride = decodeObject(capabilityOverrideBytes)
|
|
model.Capabilities = decodeObject(capabilitiesBytes)
|
|
model.BillingConfigOverride = decodeObject(billingOverrideBytes)
|
|
model.BillingConfig = decodeObject(billingBytes)
|
|
return model, nil
|
|
}
|
|
|
|
func (s *Store) lookupBaseModel(ctx context.Context, id string, canonicalKey string, modelName string) (modelCatalogSnapshot, error) {
|
|
var item modelCatalogSnapshot
|
|
var capabilities []byte
|
|
var billingConfig []byte
|
|
var rateLimitPolicy []byte
|
|
err := s.pool.QueryRow(ctx, `
|
|
SELECT id::text, provider_key, canonical_model_key, provider_model_name, model_type, display_name,
|
|
capabilities, base_billing_config, default_rate_limit_policy
|
|
FROM base_model_catalog
|
|
WHERE ($1 <> '' AND id = NULLIF($1, '')::uuid)
|
|
OR ($2 <> '' AND canonical_model_key = $2)
|
|
OR ($3 <> '' AND provider_model_name = $3)
|
|
ORDER BY CASE WHEN id::text = $1 THEN 0 WHEN canonical_model_key = $2 THEN 1 ELSE 2 END
|
|
LIMIT 1`, strings.TrimSpace(id), strings.TrimSpace(canonicalKey), strings.TrimSpace(modelName)).Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.CanonicalModelKey,
|
|
&item.ProviderModelName,
|
|
&item.ModelType,
|
|
&item.DisplayName,
|
|
&capabilities,
|
|
&billingConfig,
|
|
&rateLimitPolicy,
|
|
)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return modelCatalogSnapshot{}, err
|
|
}
|
|
return modelCatalogSnapshot{}, err
|
|
}
|
|
item.Capabilities = decodeObject(capabilities)
|
|
item.BaseBillingConfig = decodeObject(billingConfig)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
return item, nil
|
|
}
|
|
|
|
// mergeObjects returns a new map holding every entry of base overlaid with
// every entry of override; on a key present in both, the override value wins.
// The merge is shallow: nested values are copied by reference, not merged.
// Neither input is modified. When both inputs are nil or empty the result is
// nil rather than an empty map.
func mergeObjects(base map[string]any, override map[string]any) map[string]any {
	total := len(base) + len(override)
	if total == 0 {
		// Nothing to merge; avoid allocating a map only to discard it.
		return nil
	}
	// total is an upper bound on the result size (shared keys collapse).
	out := make(map[string]any, total)
	for key, value := range base {
		out[key] = value
	}
	for key, value := range override {
		out[key] = value
	}
	return out
}
|
|
|
|
// emptyObjectIfNil returns value unchanged unless it is nil, in which case it
// returns a fresh empty map. A nil map is encoded by encoding/json as null,
// whereas an empty map is encoded as {}.
func emptyObjectIfNil(value map[string]any) map[string]any {
	if value == nil {
		return map[string]any{}
	}
	return value
}
|