172 lines
5.3 KiB
Go
172 lines
5.3 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
)
|
|
|
|
// catalogProviderColumns is the column list appended after RETURNING by the
// catalog provider queries in this file. The columns appear in exactly the
// order scanCatalogProvider passes its destinations to Scan, so the two must
// be changed together.
//
// Nullable text columns are coalesced so they can be scanned into plain
// strings: provider_code falls back to provider_key when it is NULL or empty,
// while icon_path, default_base_url, default_auth_type and source become ''.
//
// The literal must contain nothing but SQL: any stray separator characters
// inside the backticks would be sent to the database as part of the statement.
const catalogProviderColumns = `
	id::text, provider_key, COALESCE(NULLIF(provider_code, ''), provider_key) AS provider_code,
	display_name, provider_type, COALESCE(icon_path, '') AS icon_path,
	COALESCE(default_base_url, '') AS default_base_url, COALESCE(default_auth_type, '') AS default_auth_type,
	COALESCE(source, '') AS source, capability_schema, default_rate_limit_policy,
	metadata, status, created_at, updated_at`
|
|
|
|
// CatalogProviderInput is the caller-supplied payload for creating or
// updating a catalog provider row. String fields are whitespace-trimmed and
// defaulted by normalizeCatalogProviderInput before they reach the database.
type CatalogProviderInput struct {
	// ProviderKey is written to the provider_key column.
	ProviderKey string `json:"providerKey"`
	// Code is written to provider_code; when empty it defaults to ProviderKey.
	Code string `json:"code"`
	// DisplayName is written to the display_name column.
	DisplayName string `json:"displayName"`
	// ProviderType defaults to "openai" when empty.
	ProviderType string `json:"providerType"`
	// IconPath is stored as NULL when empty.
	IconPath string `json:"iconPath"`
	// DefaultBaseURL is stored as NULL when empty.
	DefaultBaseURL string `json:"defaultBaseUrl"`
	// DefaultAuthType defaults to "APIKey" when empty.
	DefaultAuthType string `json:"defaultAuthType"`
	// Source defaults to "gateway" when empty.
	Source string `json:"source"`
	// CapabilitySchema is serialized to JSON for the capability_schema column.
	CapabilitySchema map[string]any `json:"capabilitySchema"`
	// DefaultRateLimitPolicy is serialized to JSON for the
	// default_rate_limit_policy column.
	DefaultRateLimitPolicy map[string]any `json:"defaultRateLimitPolicy"`
	// Metadata is serialized to JSON for the metadata column.
	Metadata map[string]any `json:"metadata"`
	// Status defaults to "active" when empty.
	Status string `json:"status"`
}
|
|
|
|
// catalogProviderScanner is the minimal row abstraction scanCatalogProvider
// needs: anything that can copy the current row's columns into the supplied
// destinations. The value returned by s.pool.QueryRow is passed as one in
// this file.
type catalogProviderScanner interface {
	// Scan copies the row's column values into dest, in column order.
	Scan(dest ...any) error
}
|
|
|
|
func (s *Store) CreateCatalogProvider(ctx context.Context, input CatalogProviderInput) (CatalogProvider, error) {
|
|
input = normalizeCatalogProviderInput(input)
|
|
capabilitySchema, _ := json.Marshal(emptyObjectIfNil(input.CapabilitySchema))
|
|
rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
|
|
return scanCatalogProvider(s.pool.QueryRow(ctx, `
|
|
INSERT INTO model_catalog_providers (
|
|
provider_key, provider_code, display_name, provider_type, icon_path, default_base_url, default_auth_type, source,
|
|
capability_schema, default_rate_limit_policy, metadata, status
|
|
)
|
|
VALUES ($1, $2, $3, $4, NULLIF($5, ''), NULLIF($6, ''), $7, $8, $9, $10, $11, $12)
|
|
RETURNING `+catalogProviderColumns,
|
|
input.ProviderKey,
|
|
input.Code,
|
|
input.DisplayName,
|
|
input.ProviderType,
|
|
input.IconPath,
|
|
input.DefaultBaseURL,
|
|
input.DefaultAuthType,
|
|
input.Source,
|
|
capabilitySchema,
|
|
rateLimitPolicy,
|
|
metadata,
|
|
input.Status,
|
|
))
|
|
}
|
|
|
|
func (s *Store) UpdateCatalogProvider(ctx context.Context, id string, input CatalogProviderInput) (CatalogProvider, error) {
|
|
input = normalizeCatalogProviderInput(input)
|
|
capabilitySchema, _ := json.Marshal(emptyObjectIfNil(input.CapabilitySchema))
|
|
rateLimitPolicy, _ := json.Marshal(emptyObjectIfNil(input.DefaultRateLimitPolicy))
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
|
|
return scanCatalogProvider(s.pool.QueryRow(ctx, `
|
|
UPDATE model_catalog_providers
|
|
SET provider_key = $2,
|
|
provider_code = $3,
|
|
display_name = $4,
|
|
provider_type = $5,
|
|
icon_path = NULLIF($6, ''),
|
|
default_base_url = NULLIF($7, ''),
|
|
default_auth_type = $8,
|
|
source = $9,
|
|
capability_schema = $10,
|
|
default_rate_limit_policy = $11,
|
|
metadata = $12,
|
|
status = $13,
|
|
updated_at = now()
|
|
WHERE id = $1::uuid
|
|
RETURNING `+catalogProviderColumns,
|
|
id,
|
|
input.ProviderKey,
|
|
input.Code,
|
|
input.DisplayName,
|
|
input.ProviderType,
|
|
input.IconPath,
|
|
input.DefaultBaseURL,
|
|
input.DefaultAuthType,
|
|
input.Source,
|
|
capabilitySchema,
|
|
rateLimitPolicy,
|
|
metadata,
|
|
input.Status,
|
|
))
|
|
}
|
|
|
|
func (s *Store) DeleteCatalogProvider(ctx context.Context, id string) error {
|
|
result, err := s.pool.Exec(ctx, `DELETE FROM model_catalog_providers WHERE id = $1::uuid`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.RowsAffected() == 0 {
|
|
return pgx.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanCatalogProvider(scanner catalogProviderScanner) (CatalogProvider, error) {
|
|
var item CatalogProvider
|
|
var capabilitySchema []byte
|
|
var rateLimitPolicy []byte
|
|
var metadata []byte
|
|
if err := scanner.Scan(
|
|
&item.ID,
|
|
&item.ProviderKey,
|
|
&item.Code,
|
|
&item.DisplayName,
|
|
&item.ProviderType,
|
|
&item.IconPath,
|
|
&item.DefaultBaseURL,
|
|
&item.DefaultAuthType,
|
|
&item.Source,
|
|
&capabilitySchema,
|
|
&rateLimitPolicy,
|
|
&metadata,
|
|
&item.Status,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return CatalogProvider{}, err
|
|
}
|
|
item.CapabilitySchema = decodeObject(capabilitySchema)
|
|
item.DefaultRateLimitPolicy = decodeObject(rateLimitPolicy)
|
|
item.Metadata = decodeObject(metadata)
|
|
return item, nil
|
|
}
|
|
|
|
func normalizeCatalogProviderInput(input CatalogProviderInput) CatalogProviderInput {
|
|
input.ProviderKey = strings.TrimSpace(input.ProviderKey)
|
|
input.Code = strings.TrimSpace(input.Code)
|
|
input.DisplayName = strings.TrimSpace(input.DisplayName)
|
|
input.ProviderType = strings.TrimSpace(input.ProviderType)
|
|
input.IconPath = strings.TrimSpace(input.IconPath)
|
|
input.DefaultBaseURL = strings.TrimSpace(input.DefaultBaseURL)
|
|
input.DefaultAuthType = strings.TrimSpace(input.DefaultAuthType)
|
|
input.Source = strings.TrimSpace(input.Source)
|
|
input.Status = strings.TrimSpace(input.Status)
|
|
if input.Code == "" {
|
|
input.Code = input.ProviderKey
|
|
}
|
|
if input.ProviderType == "" {
|
|
input.ProviderType = "openai"
|
|
}
|
|
if input.DefaultAuthType == "" {
|
|
input.DefaultAuthType = "APIKey"
|
|
}
|
|
if input.Source == "" {
|
|
input.Source = "gateway"
|
|
}
|
|
if input.Status == "" {
|
|
input.Status = "active"
|
|
}
|
|
return input
|
|
}
|