easyai-ai-gateway/apps/api/internal/httpapi/model_catalog_test.go

222 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package httpapi
import (
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestBuildModelCatalogAggregatesSources verifies that two platform models
// sharing ModelName "seedance" / ModelAlias "Seedance-2.0" on different
// platforms are merged into a single catalog item, and that the merged item
// aggregates rate limits, permissions, discount, providers, capability
// filters and pricing across both sources.
func TestBuildModelCatalogAggregatesSources(t *testing.T) {
	models := []store.PlatformModel{
		{
			ID:          "model-a",
			PlatformID:  "platform-a",
			ModelName:   "seedance",
			ModelAlias:  "Seedance-2.0",
			ModelType:   store.StringList{"image_generate"},
			DisplayName: "Seedance Source A",
			BillingConfig: map[string]any{
				"image": map[string]any{"basePrice": float64(10), "dynamicWeight": map[string]any{"1K": float64(1), "2K": float64(2)}},
			},
			// Rate limits expressed in the nested "platformLimits" shape.
			RateLimitPolicy: map[string]any{
				"platformLimits": map[string]any{
					"max_request_per_minute":  60,
					"max_token_per_minute":    1000,
					"max_concurrent_requests": 2,
				},
			},
			PricingMode: "inherit_discount",
			Enabled:     true,
		},
		{
			ID:          "model-b",
			PlatformID:  "platform-b",
			ModelName:   "seedance",
			ModelAlias:  "Seedance-2.0",
			ModelType:   store.StringList{"image_generate"},
			DisplayName: "Seedance Source B",
			BillingConfig: map[string]any{
				"image": map[string]any{"basePrice": float64(10), "dynamicWeight": map[string]any{"1K": float64(1), "2K": float64(2)}},
			},
			// Rate limits expressed in the flat rpm/tpm/concurrent shape.
			RateLimitPolicy: map[string]any{
				"rpm":        40,
				"tpm":        2000,
				"concurrent": 3,
			},
			DiscountFactor: 0.8,
			PricingMode:    "custom",
			Enabled:        true,
		},
	}
	platforms := []store.Platform{
		{ID: "platform-a", Provider: "volces", Name: "火山引擎", Status: "enabled", Priority: 20, DefaultDiscountFactor: 1},
		{ID: "platform-b", Provider: "gemini", Name: "Gemini", Status: "enabled", Priority: 10, DefaultDiscountFactor: 1},
	}
	providers := []store.CatalogProvider{
		{ProviderKey: "volces", DisplayName: "火山引擎", IconPath: "volces.png"},
		{ProviderKey: "gemini", DisplayName: "Google Gemini", IconPath: "gemini.png"},
	}
	// group-vip is allowed on platform-b; group-blocked is denied on platform-a.
	accessRules := []store.AccessRule{
		{SubjectType: "user_group", SubjectID: "group-vip", ResourceType: "platform", ResourceID: "platform-b", Effect: "allow", Status: "active"},
		{SubjectType: "user_group", SubjectID: "group-blocked", ResourceType: "platform", ResourceID: "platform-a", Effect: "deny", Status: "active"},
	}
	userGroups := []store.UserGroup{
		{ID: "group-vip", GroupKey: "vip", Name: "VIP 用户组"},
		{ID: "group-blocked", GroupKey: "blocked", Name: "Blocked 用户组"},
	}
	// NOTE(review): this base model has an empty ID, and neither platform model
	// sets BaseModelID; presumably it is meant to match models without a base
	// model. Nothing below asserts on its description — confirm the intent.
	baseModels := []store.BaseModel{
		{ID: "", Metadata: map[string]any{"description": "高质量图像生成模型"}},
	}
	response := buildModelCatalog(models, platforms, providers, nil, accessRules, userGroups, baseModels)
	if response.Summary.ModelCount != 1 || response.Summary.SourceCount != 2 {
		t.Fatalf("unexpected summary: %+v", response.Summary)
	}
	// Guard the index so a regression fails with a message instead of a panic.
	if len(response.Items) == 0 {
		t.Fatalf("expected a merged catalog item, got none")
	}
	item := response.Items[0]
	if item.SourceCount != 2 {
		t.Fatalf("expected merged source count, got %d", item.SourceCount)
	}
	if item.Source.Label != "2 个源" {
		t.Fatalf("expected source label to only show count, got %q", item.Source.Label)
	}
	// Limits are summed across sources: 60+40, 1000+2000, 2+3.
	if item.RateLimits.RPM == nil || *item.RateLimits.RPM != 100 {
		t.Fatalf("expected summed rpm 100, got %+v", item.RateLimits.RPM)
	}
	if item.RateLimits.TPM == nil || *item.RateLimits.TPM != 3000 {
		t.Fatalf("expected summed tpm 3000, got %+v", item.RateLimits.TPM)
	}
	if item.RateLimits.Concurrent == nil || *item.RateLimits.Concurrent != 5 {
		t.Fatalf("expected summed concurrency 5, got %+v", item.RateLimits.Concurrent)
	}
	if item.Permission.Label != "用户组 VIP 用户组;拒绝 Blocked 用户组" {
		t.Fatalf("expected permission label from access rules, got %q", item.Permission.Label)
	}
	if len(item.Permission.AllowGroups) != 1 || item.Permission.AllowGroups[0] != "VIP 用户组" {
		t.Fatalf("expected allow permission groups, got %+v", item.Permission.AllowGroups)
	}
	if len(item.Permission.DenyGroups) != 1 || item.Permission.DenyGroups[0] != "Blocked 用户组" {
		t.Fatalf("expected deny permission groups, got %+v", item.Permission.DenyGroups)
	}
	// Presumably a range across sources: model-b's custom 0.8 factor and
	// model-a inheriting the platform default of 1 (no discount).
	if item.Discount.Label != "80% - 无折扣" {
		t.Fatalf("expected friendly discount label, got %q", item.Discount.Label)
	}
	if len(item.ProviderKeys) != 2 {
		t.Fatalf("expected both providers on merged item, got %+v", item.ProviderKeys)
	}
	if !hasFilterCount(response.Filters.Providers, "volces", 1) || !hasFilterCount(response.Filters.Providers, "gemini", 1) {
		t.Fatalf("expected provider filters to count merged model for each provider: %+v", response.Filters.Providers)
	}
	if !hasFilterCount(response.Filters.Capabilities, "image", 1) {
		t.Fatalf("expected image capability filter: %+v", response.Filters.Capabilities)
	}
	// basePrice 10 scaled by dynamic weights 1K=1 and 2K=2.
	if len(item.Pricing.Lines) == 0 {
		t.Fatalf("expected at least one pricing line, got none")
	}
	if got := item.Pricing.Lines[0]; got != "图像1K 10 / 2K 20" {
		t.Fatalf("unexpected pricing line %q", got)
	}
}
// TestBuildModelCatalogUsesBaseModelProviderForProviderFilters verifies that
// platform models referencing a base model (BaseModelID "base-glm") are
// attributed to the base model's ProviderKey ("zhipu-openai") rather than to
// each hosting platform's provider. This holds both on the merged item and in
// the provider filter counts.
func TestBuildModelCatalogUsesBaseModelProviderForProviderFilters(t *testing.T) {
	models := []store.PlatformModel{
		{
			ID:          "glm-volces",
			PlatformID:  "platform-volces",
			BaseModelID: "base-glm",
			ModelName:   "glm-4.7",
			ModelAlias:  "GLM-4.7",
			ModelType:   store.StringList{"text_generate"},
			DisplayName: "GLM-4.7",
			Enabled:     true,
		},
		{
			ID:          "glm-zhipu",
			PlatformID:  "platform-zhipu",
			BaseModelID: "base-glm",
			ModelName:   "glm-4.7",
			ModelAlias:  "GLM-4.7",
			ModelType:   store.StringList{"text_generate"},
			DisplayName: "GLM-4.7",
			Enabled:     true,
		},
	}
	platforms := []store.Platform{
		{ID: "platform-volces", Provider: "volces-openai", Name: "火山引擎(OpenAI兼容)", Status: "enabled"},
		{ID: "platform-zhipu", Provider: "zhipu-openai", Name: "智谱官方", Status: "enabled"},
	}
	providers := []store.CatalogProvider{
		{ProviderKey: "volces-openai", DisplayName: "火山引擎(OpenAI兼容)", IconPath: "volces.png"},
		{ProviderKey: "zhipu-openai", DisplayName: "智谱AI", IconPath: "zhipu.png"},
	}
	baseModels := []store.BaseModel{
		{ID: "base-glm", ProviderKey: "zhipu-openai", ProviderModelName: "glm-4.7", ModelAlias: "GLM-4.7"},
	}
	response := buildModelCatalog(models, platforms, providers, nil, nil, nil, baseModels)
	if response.Summary.ModelCount != 1 || response.Summary.SourceCount != 2 {
		t.Fatalf("unexpected summary: %+v", response.Summary)
	}
	// Guard the index so a regression fails with a message instead of a panic.
	if len(response.Items) == 0 {
		t.Fatalf("expected a merged catalog item, got none")
	}
	item := response.Items[0]
	if len(item.ProviderKeys) != 1 || item.ProviderKeys[0] != "zhipu-openai" {
		t.Fatalf("expected model provider zhipu-openai only, got %+v", item.ProviderKeys)
	}
	if len(item.Providers) != 1 || item.Providers[0].Name != "智谱AI" || item.Providers[0].SourceCount != 2 {
		t.Fatalf("expected provider summary to aggregate both sources under model provider, got %+v", item.Providers)
	}
	if !hasFilterCount(response.Filters.Providers, "zhipu-openai", 1) {
		t.Fatalf("expected zhipu provider filter count 1, got %+v", response.Filters.Providers)
	}
	// The platform's own provider must not be counted at all. Checking only
	// for Count == 1 would let an entry with any other positive count slip
	// through.
	for _, option := range response.Filters.Providers {
		if option.Value == "volces-openai" && option.Count > 0 {
			t.Fatalf("did not expect platform provider in model provider filters: %+v", response.Filters.Providers)
		}
	}
}
// TestBillingConfigLinesShowsTextInputAndOutputPricing checks that a
// "text_total" billing entry carrying a formulaConfig is rendered as exactly
// two lines: the input token price first, then the output token price.
func TestBillingConfigLinesShowsTextInputAndOutputPricing(t *testing.T) {
	config := map[string]any{
		"text_total": map[string]any{
			"basePrice": 0.01,
			"formulaConfig": map[string]any{
				"inputTokenPrice":  0.01,
				"outputTokenPrice": 0.03,
			},
		},
	}
	got := billingConfigLines(config)
	if len(got) != 2 {
		t.Fatalf("expected input and output pricing lines, got %+v", got)
	}
	// Each expected line is paired with the failure message used when it
	// does not match; entries are checked in order.
	expectations := []struct {
		want   string
		format string
	}{
		{want: "输入 0.01/k tokens", format: "unexpected input pricing line %q"},
		{want: "输出 0.03/k tokens", format: "unexpected output pricing line %q"},
	}
	for idx, exp := range expectations {
		if got[idx] != exp.want {
			t.Fatalf(exp.format, got[idx])
		}
	}
}
// TestBillingConfigLinesShowsVideoFiveSecondBasis checks the single video
// pricing line in two configurations. The first is a "video" entry with a
// base price and per-resolution weights; the second is a flat "videoBase"
// price. Both are expected to mention the 5-second basis ("5秒基准").
func TestBillingConfigLinesShowsVideoFiveSecondBasis(t *testing.T) {
	tiered := billingConfigLines(map[string]any{
		"video": map[string]any{
			"basePrice":     float64(75),
			"dynamicWeight": map[string]any{"480p": float64(1), "720p": float64(2)},
		},
	})
	if len(tiered) != 1 {
		t.Fatalf("expected one video pricing line, got %+v", tiered)
	}
	// NOTE(review): the expected text places the 720p price (presumably
	// 75 × weight 2 = 150) directly before "5秒基准" with no separator, so it
	// reads as "1505秒基准". The flat form below uses " / " before the suffix.
	// Confirm whether the tiered formatter should insert a separator too.
	if want := "视频480p 75 / 720p 1505秒基准"; tiered[0] != want {
		t.Fatalf("unexpected video pricing line %q", tiered[0])
	}

	flat := billingConfigLines(map[string]any{"videoBase": float64(100)})
	if !(len(flat) == 1 && flat[0] == "视频100 / 5秒基准") {
		t.Fatalf("unexpected flat video pricing line %+v", flat)
	}
}
// hasFilterCount reports whether options contains at least one entry whose
// Value equals value and whose Count equals count exactly. Entries with a
// matching Value but a different Count do not stop the search.
func hasFilterCount(options []ModelCatalogFilterOption, value string, count int) bool {
	for i := range options {
		if options[i].Value != value {
			continue
		}
		if options[i].Count == count {
			return true
		}
	}
	return false
}