222 lines
8.2 KiB
Go
222 lines
8.2 KiB
Go
package httpapi
|
||
|
||
import (
|
||
"testing"
|
||
|
||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||
)
|
||
|
||
func TestBuildModelCatalogAggregatesSources(t *testing.T) {
|
||
models := []store.PlatformModel{
|
||
{
|
||
ID: "model-a",
|
||
PlatformID: "platform-a",
|
||
ModelName: "seedance",
|
||
ModelAlias: "Seedance-2.0",
|
||
ModelType: store.StringList{"image_generate"},
|
||
DisplayName: "Seedance Source A",
|
||
BillingConfig: map[string]any{
|
||
"image": map[string]any{"basePrice": float64(10), "dynamicWeight": map[string]any{"1K": float64(1), "2K": float64(2)}},
|
||
},
|
||
RateLimitPolicy: map[string]any{
|
||
"platformLimits": map[string]any{
|
||
"max_request_per_minute": 60,
|
||
"max_token_per_minute": 1000,
|
||
"max_concurrent_requests": 2,
|
||
},
|
||
},
|
||
PricingMode: "inherit_discount",
|
||
Enabled: true,
|
||
},
|
||
{
|
||
ID: "model-b",
|
||
PlatformID: "platform-b",
|
||
ModelName: "seedance",
|
||
ModelAlias: "Seedance-2.0",
|
||
ModelType: store.StringList{"image_generate"},
|
||
DisplayName: "Seedance Source B",
|
||
BillingConfig: map[string]any{
|
||
"image": map[string]any{"basePrice": float64(10), "dynamicWeight": map[string]any{"1K": float64(1), "2K": float64(2)}},
|
||
},
|
||
RateLimitPolicy: map[string]any{
|
||
"rpm": 40,
|
||
"tpm": 2000,
|
||
"concurrent": 3,
|
||
},
|
||
DiscountFactor: 0.8,
|
||
PricingMode: "custom",
|
||
Enabled: true,
|
||
},
|
||
}
|
||
platforms := []store.Platform{
|
||
{ID: "platform-a", Provider: "volces", Name: "火山引擎", Status: "enabled", Priority: 20, DefaultDiscountFactor: 1},
|
||
{ID: "platform-b", Provider: "gemini", Name: "Gemini", Status: "enabled", Priority: 10, DefaultDiscountFactor: 1},
|
||
}
|
||
providers := []store.CatalogProvider{
|
||
{ProviderKey: "volces", DisplayName: "火山引擎", IconPath: "volces.png"},
|
||
{ProviderKey: "gemini", DisplayName: "Google Gemini", IconPath: "gemini.png"},
|
||
}
|
||
accessRules := []store.AccessRule{
|
||
{SubjectType: "user_group", SubjectID: "group-vip", ResourceType: "platform", ResourceID: "platform-b", Effect: "allow", Status: "active"},
|
||
{SubjectType: "user_group", SubjectID: "group-blocked", ResourceType: "platform", ResourceID: "platform-a", Effect: "deny", Status: "active"},
|
||
}
|
||
userGroups := []store.UserGroup{
|
||
{ID: "group-vip", GroupKey: "vip", Name: "VIP 用户组"},
|
||
{ID: "group-blocked", GroupKey: "blocked", Name: "Blocked 用户组"},
|
||
}
|
||
baseModels := []store.BaseModel{
|
||
{ID: "", Metadata: map[string]any{"description": "高质量图像生成模型"}},
|
||
}
|
||
|
||
response := buildModelCatalog(models, platforms, providers, nil, accessRules, userGroups, baseModels)
|
||
if response.Summary.ModelCount != 1 || response.Summary.SourceCount != 2 {
|
||
t.Fatalf("unexpected summary: %+v", response.Summary)
|
||
}
|
||
item := response.Items[0]
|
||
if item.SourceCount != 2 {
|
||
t.Fatalf("expected merged source count, got %d", item.SourceCount)
|
||
}
|
||
if item.Source.Label != "2 个源" {
|
||
t.Fatalf("expected source label to only show count, got %q", item.Source.Label)
|
||
}
|
||
if item.RateLimits.RPM == nil || *item.RateLimits.RPM != 100 {
|
||
t.Fatalf("expected summed rpm 100, got %+v", item.RateLimits.RPM)
|
||
}
|
||
if item.RateLimits.TPM == nil || *item.RateLimits.TPM != 3000 {
|
||
t.Fatalf("expected summed tpm 3000, got %+v", item.RateLimits.TPM)
|
||
}
|
||
if item.RateLimits.Concurrent == nil || *item.RateLimits.Concurrent != 5 {
|
||
t.Fatalf("expected summed concurrency 5, got %+v", item.RateLimits.Concurrent)
|
||
}
|
||
if item.Permission.Label != "用户组 VIP 用户组;拒绝 Blocked 用户组" {
|
||
t.Fatalf("expected permission label from access rules, got %q", item.Permission.Label)
|
||
}
|
||
if len(item.Permission.AllowGroups) != 1 || item.Permission.AllowGroups[0] != "VIP 用户组" {
|
||
t.Fatalf("expected allow permission groups, got %+v", item.Permission.AllowGroups)
|
||
}
|
||
if len(item.Permission.DenyGroups) != 1 || item.Permission.DenyGroups[0] != "Blocked 用户组" {
|
||
t.Fatalf("expected deny permission groups, got %+v", item.Permission.DenyGroups)
|
||
}
|
||
if item.Discount.Label != "80% - 无折扣" {
|
||
t.Fatalf("expected friendly discount label, got %q", item.Discount.Label)
|
||
}
|
||
if len(item.ProviderKeys) != 2 {
|
||
t.Fatalf("expected both providers on merged item, got %+v", item.ProviderKeys)
|
||
}
|
||
if !hasFilterCount(response.Filters.Providers, "volces", 1) || !hasFilterCount(response.Filters.Providers, "gemini", 1) {
|
||
t.Fatalf("expected provider filters to count merged model for each provider: %+v", response.Filters.Providers)
|
||
}
|
||
if !hasFilterCount(response.Filters.Capabilities, "image", 1) {
|
||
t.Fatalf("expected image capability filter: %+v", response.Filters.Capabilities)
|
||
}
|
||
if got := item.Pricing.Lines[0]; got != "图像:1K 10 / 2K 20" {
|
||
t.Fatalf("unexpected pricing line %q", got)
|
||
}
|
||
}
|
||
|
||
func TestBuildModelCatalogUsesBaseModelProviderForProviderFilters(t *testing.T) {
|
||
models := []store.PlatformModel{
|
||
{
|
||
ID: "glm-volces",
|
||
PlatformID: "platform-volces",
|
||
BaseModelID: "base-glm",
|
||
ModelName: "glm-4.7",
|
||
ModelAlias: "GLM-4.7",
|
||
ModelType: store.StringList{"text_generate"},
|
||
DisplayName: "GLM-4.7",
|
||
Enabled: true,
|
||
},
|
||
{
|
||
ID: "glm-zhipu",
|
||
PlatformID: "platform-zhipu",
|
||
BaseModelID: "base-glm",
|
||
ModelName: "glm-4.7",
|
||
ModelAlias: "GLM-4.7",
|
||
ModelType: store.StringList{"text_generate"},
|
||
DisplayName: "GLM-4.7",
|
||
Enabled: true,
|
||
},
|
||
}
|
||
platforms := []store.Platform{
|
||
{ID: "platform-volces", Provider: "volces-openai", Name: "火山引擎(OpenAI兼容)", Status: "enabled"},
|
||
{ID: "platform-zhipu", Provider: "zhipu-openai", Name: "智谱官方", Status: "enabled"},
|
||
}
|
||
providers := []store.CatalogProvider{
|
||
{ProviderKey: "volces-openai", DisplayName: "火山引擎(OpenAI兼容)", IconPath: "volces.png"},
|
||
{ProviderKey: "zhipu-openai", DisplayName: "智谱AI", IconPath: "zhipu.png"},
|
||
}
|
||
baseModels := []store.BaseModel{
|
||
{ID: "base-glm", ProviderKey: "zhipu-openai", ProviderModelName: "glm-4.7", ModelAlias: "GLM-4.7"},
|
||
}
|
||
|
||
response := buildModelCatalog(models, platforms, providers, nil, nil, nil, baseModels)
|
||
if response.Summary.ModelCount != 1 || response.Summary.SourceCount != 2 {
|
||
t.Fatalf("unexpected summary: %+v", response.Summary)
|
||
}
|
||
item := response.Items[0]
|
||
if len(item.ProviderKeys) != 1 || item.ProviderKeys[0] != "zhipu-openai" {
|
||
t.Fatalf("expected model provider zhipu-openai only, got %+v", item.ProviderKeys)
|
||
}
|
||
if len(item.Providers) != 1 || item.Providers[0].Name != "智谱AI" || item.Providers[0].SourceCount != 2 {
|
||
t.Fatalf("expected provider summary to aggregate both sources under model provider, got %+v", item.Providers)
|
||
}
|
||
if !hasFilterCount(response.Filters.Providers, "zhipu-openai", 1) {
|
||
t.Fatalf("expected zhipu provider filter count 1, got %+v", response.Filters.Providers)
|
||
}
|
||
if hasFilterCount(response.Filters.Providers, "volces-openai", 1) {
|
||
t.Fatalf("did not expect platform provider in model provider filters: %+v", response.Filters.Providers)
|
||
}
|
||
}
|
||
|
||
func TestBillingConfigLinesShowsTextInputAndOutputPricing(t *testing.T) {
|
||
lines := billingConfigLines(map[string]any{
|
||
"text_total": map[string]any{
|
||
"basePrice": 0.01,
|
||
"formulaConfig": map[string]any{
|
||
"inputTokenPrice": 0.01,
|
||
"outputTokenPrice": 0.03,
|
||
},
|
||
},
|
||
})
|
||
|
||
if len(lines) != 2 {
|
||
t.Fatalf("expected input and output pricing lines, got %+v", lines)
|
||
}
|
||
if lines[0] != "输入 0.01/k tokens" {
|
||
t.Fatalf("unexpected input pricing line %q", lines[0])
|
||
}
|
||
if lines[1] != "输出 0.03/k tokens" {
|
||
t.Fatalf("unexpected output pricing line %q", lines[1])
|
||
}
|
||
}
|
||
|
||
func TestBillingConfigLinesShowsVideoFiveSecondBasis(t *testing.T) {
|
||
lines := billingConfigLines(map[string]any{
|
||
"video": map[string]any{
|
||
"basePrice": float64(75),
|
||
"dynamicWeight": map[string]any{"480p": float64(1), "720p": float64(2)},
|
||
},
|
||
})
|
||
|
||
if len(lines) != 1 {
|
||
t.Fatalf("expected one video pricing line, got %+v", lines)
|
||
}
|
||
if lines[0] != "视频:480p 75 / 720p 150(5秒基准)" {
|
||
t.Fatalf("unexpected video pricing line %q", lines[0])
|
||
}
|
||
|
||
flatLines := billingConfigLines(map[string]any{"videoBase": float64(100)})
|
||
if len(flatLines) != 1 || flatLines[0] != "视频:100 / 5秒基准" {
|
||
t.Fatalf("unexpected flat video pricing line %+v", flatLines)
|
||
}
|
||
}
|
||
|
||
func hasFilterCount(options []ModelCatalogFilterOption, value string, count int) bool {
|
||
for _, option := range options {
|
||
if option.Value == value && option.Count == count {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|