package httpapi

import (
	"testing"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)

// TestBuildModelCatalogAggregatesSources builds a catalog from two platform
// models that share a model name and alias but live on different platforms,
// and checks that they collapse into a single catalog item whose rate limits,
// permissions, discount label, provider filters and pricing reflect both
// sources.
func TestBuildModelCatalogAggregatesSources(t *testing.T) {
	models := []store.PlatformModel{
		{
			ID:          "model-a",
			PlatformID:  "platform-a",
			ModelName:   "seedance",
			ModelAlias:  "Seedance-2.0",
			ModelType:   store.StringList{"image_generate"},
			DisplayName: "Seedance Source A",
			BillingConfig: map[string]any{
				"image": map[string]any{"basePrice": float64(10), "dynamicWeight": map[string]any{"1K": float64(1), "2K": float64(2)}},
			},
			// Nested "platformLimits" policy shape; the other source uses the
			// flat shape, so the summed expectations below require both to be read.
			RateLimitPolicy: map[string]any{
				"platformLimits": map[string]any{
					"max_request_per_minute":  60,
					"max_token_per_minute":    1000,
					"max_concurrent_requests": 2,
				},
			},
			PricingMode: "inherit_discount",
			Enabled:     true,
		},
		{
			ID:          "model-b",
			PlatformID:  "platform-b",
			ModelName:   "seedance",
			ModelAlias:  "Seedance-2.0",
			ModelType:   store.StringList{"image_generate"},
			DisplayName: "Seedance Source B",
			BillingConfig: map[string]any{
				"image": map[string]any{"basePrice": float64(10), "dynamicWeight": map[string]any{"1K": float64(1), "2K": float64(2)}},
			},
			// Flat rpm/tpm/concurrent policy shape.
			RateLimitPolicy: map[string]any{
				"rpm":        40,
				"tpm":        2000,
				"concurrent": 3,
			},
			DiscountFactor: 0.8,
			PricingMode:    "custom",
			Enabled:        true,
		},
	}
	platforms := []store.Platform{
		{ID: "platform-a", Provider: "volces", Name: "火山引擎", Status: "enabled", Priority: 20, DefaultDiscountFactor: 1},
		{ID: "platform-b", Provider: "gemini", Name: "Gemini", Status: "enabled", Priority: 10, DefaultDiscountFactor: 1},
	}
	providers := []store.CatalogProvider{
		{ProviderKey: "volces", DisplayName: "火山引擎", IconPath: "volces.png"},
		{ProviderKey: "gemini", DisplayName: "Google Gemini", IconPath: "gemini.png"},
	}
	accessRules := []store.AccessRule{
		{SubjectType: "user_group", SubjectID: "group-vip", ResourceType: "platform", ResourceID: "platform-b", Effect: "allow", Status: "active"},
		{SubjectType: "user_group", SubjectID: "group-blocked", ResourceType: "platform", ResourceID: "platform-a", Effect: "deny", Status: "active"},
	}
	userGroups := []store.UserGroup{
		{ID: "group-vip", GroupKey: "vip", Name: "VIP 用户组"},
		{ID: "group-blocked", GroupKey: "blocked", Name: "Blocked 用户组"},
	}
	baseModels := []store.BaseModel{
		{ID: "", Metadata: map[string]any{"description": "高质量图像生成模型"}},
	}

	response := buildModelCatalog(models, platforms, providers, nil, accessRules, userGroups, baseModels)
	if response.Summary.ModelCount != 1 || response.Summary.SourceCount != 2 {
		t.Fatalf("unexpected summary: %+v", response.Summary)
	}
	// Guard the index below so a regression fails the test instead of panicking.
	if len(response.Items) == 0 {
		t.Fatal("expected a merged catalog item, got none")
	}
	item := response.Items[0]
	if item.SourceCount != 2 {
		t.Fatalf("expected merged source count, got %d", item.SourceCount)
	}
	if item.Source.Label != "2 个源" {
		t.Fatalf("expected source label to only show count, got %q", item.Source.Label)
	}
	// Expected rate limits are the per-source sums: 60+40, 1000+2000, 2+3.
	if item.RateLimits.RPM == nil || *item.RateLimits.RPM != 100 {
		t.Fatalf("expected summed rpm 100, got %+v", item.RateLimits.RPM)
	}
	if item.RateLimits.TPM == nil || *item.RateLimits.TPM != 3000 {
		t.Fatalf("expected summed tpm 3000, got %+v", item.RateLimits.TPM)
	}
	if item.RateLimits.Concurrent == nil || *item.RateLimits.Concurrent != 5 {
		t.Fatalf("expected summed concurrency 5, got %+v", item.RateLimits.Concurrent)
	}
	if item.Permission.Label != "用户组 VIP 用户组;拒绝 Blocked 用户组" {
		t.Fatalf("expected permission label from access rules, got %q", item.Permission.Label)
	}
	if len(item.Permission.AllowGroups) != 1 || item.Permission.AllowGroups[0] != "VIP 用户组" {
		t.Fatalf("expected allow permission groups, got %+v", item.Permission.AllowGroups)
	}
	if len(item.Permission.DenyGroups) != 1 || item.Permission.DenyGroups[0] != "Blocked 用户组" {
		t.Fatalf("expected deny permission groups, got %+v", item.Permission.DenyGroups)
	}
	if item.Discount.Label != "80% - 无折扣" {
		t.Fatalf("expected friendly discount label, got %q", item.Discount.Label)
	}
	if len(item.ProviderKeys) != 2 {
		t.Fatalf("expected both providers on merged item, got %+v", item.ProviderKeys)
	}
	if !hasFilterCount(response.Filters.Providers, "volces", 1) || !hasFilterCount(response.Filters.Providers, "gemini", 1) {
		t.Fatalf("expected provider filters to count merged model for each provider: %+v", response.Filters.Providers)
	}
	if !hasFilterCount(response.Filters.Capabilities, "image", 1) {
		t.Fatalf("expected image capability filter: %+v", response.Filters.Capabilities)
	}
	// basePrice 10 scaled by the 1K/2K dynamic weights (1 and 2).
	if len(item.Pricing.Lines) == 0 {
		t.Fatal("expected at least one pricing line, got none")
	}
	if got := item.Pricing.Lines[0]; got != "图像:1K 10 / 2K 20" {
		t.Fatalf("unexpected pricing line %q", got)
	}
}

// TestBuildModelCatalogUsesBaseModelProviderForProviderFilters checks that
// when two platform models point at the same base model, the merged catalog
// item and the provider filters use the base model's provider (zhipu-openai)
// instead of each hosting platform's provider.
func TestBuildModelCatalogUsesBaseModelProviderForProviderFilters(t *testing.T) {
	models := []store.PlatformModel{
		{
			ID:          "glm-volces",
			PlatformID:  "platform-volces",
			BaseModelID: "base-glm",
			ModelName:   "glm-4.7",
			ModelAlias:  "GLM-4.7",
			ModelType:   store.StringList{"text_generate"},
			DisplayName: "GLM-4.7",
			Enabled:     true,
		},
		{
			ID:          "glm-zhipu",
			PlatformID:  "platform-zhipu",
			BaseModelID: "base-glm",
			ModelName:   "glm-4.7",
			ModelAlias:  "GLM-4.7",
			ModelType:   store.StringList{"text_generate"},
			DisplayName: "GLM-4.7",
			Enabled:     true,
		},
	}
	platforms := []store.Platform{
		{ID: "platform-volces", Provider: "volces-openai", Name: "火山引擎(OpenAI兼容)", Status: "enabled"},
		{ID: "platform-zhipu", Provider: "zhipu-openai", Name: "智谱官方", Status: "enabled"},
	}
	providers := []store.CatalogProvider{
		{ProviderKey: "volces-openai", DisplayName: "火山引擎(OpenAI兼容)", IconPath: "volces.png"},
		{ProviderKey: "zhipu-openai", DisplayName: "智谱AI", IconPath: "zhipu.png"},
	}
	baseModels := []store.BaseModel{
		{ID: "base-glm", ProviderKey: "zhipu-openai", ProviderModelName: "glm-4.7", ModelAlias: "GLM-4.7"},
	}

	response := buildModelCatalog(models, platforms, providers, nil, nil, nil, baseModels)
	if response.Summary.ModelCount != 1 || response.Summary.SourceCount != 2 {
		t.Fatalf("unexpected summary: %+v", response.Summary)
	}
	// Guard the index below so a regression fails the test instead of panicking.
	if len(response.Items) == 0 {
		t.Fatal("expected a merged catalog item, got none")
	}
	item := response.Items[0]
	if len(item.ProviderKeys) != 1 || item.ProviderKeys[0] != "zhipu-openai" {
		t.Fatalf("expected model provider zhipu-openai only, got %+v", item.ProviderKeys)
	}
	if len(item.Providers) != 1 || item.Providers[0].Name != "智谱AI" || item.Providers[0].SourceCount != 2 {
		t.Fatalf("expected provider summary to aggregate both sources under model provider, got %+v", item.Providers)
	}
	if !hasFilterCount(response.Filters.Providers, "zhipu-openai", 1) {
		t.Fatalf("expected zhipu provider filter count 1, got %+v", response.Filters.Providers)
	}
	// The hosting platform's provider must not be counted at all; checking
	// only for a count of exactly 1 would miss e.g. a per-source count of 2.
	for _, option := range response.Filters.Providers {
		if option.Value == "volces-openai" && option.Count > 0 {
			t.Fatalf("did not expect platform provider in model provider filters: %+v", response.Filters.Providers)
		}
	}
}

// TestBillingConfigLinesShowsTextInputAndOutputPricing checks that a
// "text_total" billing entry with formula input/output token prices renders
// one line for input and one for output, each with a "/k tokens" suffix.
func TestBillingConfigLinesShowsTextInputAndOutputPricing(t *testing.T) {
	lines := billingConfigLines(map[string]any{
		"text_total": map[string]any{
			"basePrice": 0.01,
			"formulaConfig": map[string]any{
				"inputTokenPrice":  0.01,
				"outputTokenPrice": 0.03,
			},
		},
	})
	if len(lines) != 2 {
		t.Fatalf("expected input and output pricing lines, got %+v", lines)
	}
	if lines[0] != "输入 0.01/k tokens" {
		t.Fatalf("unexpected input pricing line %q", lines[0])
	}
	if lines[1] != "输出 0.03/k tokens" {
		t.Fatalf("unexpected output pricing line %q", lines[1])
	}
}

// TestBillingConfigLinesShowsVideoFiveSecondBasis checks that video pricing is
// labelled with the 5-second basis, both for resolution tiers (basePrice 75
// scaled by dynamic weights 1 and 2) and for a flat "videoBase" price.
func TestBillingConfigLinesShowsVideoFiveSecondBasis(t *testing.T) {
	lines := billingConfigLines(map[string]any{
		"video": map[string]any{
			"basePrice":     float64(75),
			"dynamicWeight": map[string]any{"480p": float64(1), "720p": float64(2)},
		},
	})
	if len(lines) != 1 {
		t.Fatalf("expected one video pricing line, got %+v", lines)
	}
	if lines[0] != "视频:480p 75 / 720p 150(5秒基准)" {
		t.Fatalf("unexpected video pricing line %q", lines[0])
	}

	flatLines := billingConfigLines(map[string]any{"videoBase": float64(100)})
	if len(flatLines) != 1 || flatLines[0] != "视频:100 / 5秒基准" {
		t.Fatalf("unexpected flat video pricing line %+v", flatLines)
	}
}

// hasFilterCount reports whether options contains an entry whose Value equals
// value and whose Count equals count exactly.
func hasFilterCount(options []ModelCatalogFilterOption, value string, count int) bool {
	for _, option := range options {
		if option.Value == value && option.Count == count {
			return true
		}
	}
	return false
}