feat: filter candidates by request resolution
This commit is contained in:
parent
ba419cd90a
commit
73c6d43e4b
@ -805,7 +805,7 @@ func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
|
||||
estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNoModelCandidate) {
|
||||
writeError(w, statusFromRunError(err), err.Error(), store.ModelCandidateErrorCode(err))
|
||||
writeErrorWithDetails(w, statusFromRunError(err), runErrorMessage(err), runErrorDetails(err), store.ModelCandidateErrorCode(err))
|
||||
return
|
||||
}
|
||||
s.logger.Error("estimate pricing failed", "error", err)
|
||||
@ -1168,6 +1168,9 @@ func runErrorDetails(err error) map[string]any {
|
||||
if detail := rateLimitErrorDetail(err); len(detail) > 0 {
|
||||
return map[string]any{"rateLimit": detail}
|
||||
}
|
||||
if detail := store.ModelCandidateErrorDetails(err); len(detail) > 0 {
|
||||
return map[string]any{"modelCandidate": detail}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
355
apps/api/internal/runner/candidate_filter.go
Normal file
355
apps/api/internal/runner/candidate_filter.go
Normal file
@ -0,0 +1,355 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
const unsupportedRequestResolutionCode = "unsupported_request_resolution"
|
||||
|
||||
type requestResolutionRequirement struct {
|
||||
Kind string
|
||||
RequestedModel string
|
||||
ModelType string
|
||||
Resolution string
|
||||
Source string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type videoResolutionReferenceStats struct {
|
||||
HasFirstFrame bool
|
||||
HasLastFrame bool
|
||||
ReferenceImages int
|
||||
HasReferenceVideo bool
|
||||
HasReferenceAudio bool
|
||||
HasAnyMedia bool
|
||||
HasExplicitContent bool
|
||||
}
|
||||
|
||||
func filterRuntimeCandidatesByRequest(kind string, requestedModel string, modelType string, body map[string]any, candidates []store.RuntimeModelCandidate) ([]store.RuntimeModelCandidate, map[string]any, error) {
|
||||
requirement, ok := requestResolutionRequirementFor(kind, requestedModel, modelType, body)
|
||||
if !ok || len(candidates) == 0 {
|
||||
return candidates, nil, nil
|
||||
}
|
||||
|
||||
filtered := make([]store.RuntimeModelCandidate, 0, len(candidates))
|
||||
rejected := make([]map[string]any, 0)
|
||||
supportedResolutions := make([]string, 0)
|
||||
for _, candidate := range candidates {
|
||||
supported, detail := candidateSupportsRequestResolution(candidate, requirement)
|
||||
if supported {
|
||||
filtered = append(filtered, candidate)
|
||||
for _, value := range stringListFromAny(detail["allowedResolutions"]) {
|
||||
appendUniqueString(&supportedResolutions, value)
|
||||
}
|
||||
continue
|
||||
}
|
||||
rejected = append(rejected, detail)
|
||||
for _, value := range stringListFromAny(detail["allowedResolutions"]) {
|
||||
appendUniqueString(&supportedResolutions, value)
|
||||
}
|
||||
}
|
||||
|
||||
summary := requestResolutionFilterSummary(requirement, len(candidates), len(filtered), rejected, supportedResolutions)
|
||||
if len(filtered) == 0 {
|
||||
return nil, summary, &store.ModelCandidateUnavailableError{
|
||||
Code: unsupportedRequestResolutionCode,
|
||||
Message: unsupportedRequestResolutionMessage(requirement, rejected),
|
||||
Details: summary,
|
||||
}
|
||||
}
|
||||
return filtered, summary, nil
|
||||
}
|
||||
|
||||
func requestResolutionRequirementFor(kind string, requestedModel string, modelType string, body map[string]any) (requestResolutionRequirement, bool) {
|
||||
if !isResolutionFilteredModelType(modelType) {
|
||||
return requestResolutionRequirement{}, false
|
||||
}
|
||||
resolution, source := requestResolutionValue(body, modelType)
|
||||
if resolution == "" {
|
||||
return requestResolutionRequirement{}, false
|
||||
}
|
||||
return requestResolutionRequirement{
|
||||
Kind: kind,
|
||||
RequestedModel: requestedModel,
|
||||
ModelType: modelType,
|
||||
Resolution: resolution,
|
||||
Source: source,
|
||||
Scopes: requestResolutionScopes(body, modelType),
|
||||
}, true
|
||||
}
|
||||
|
||||
func requestResolutionValue(body map[string]any, modelType string) (string, string) {
|
||||
if value := normalizedRequestResolution(stringFromAny(body["resolution"])); value != "" {
|
||||
return value, "resolution"
|
||||
}
|
||||
size := normalizedRequestResolution(stringFromAny(body["size"]))
|
||||
if size == "" {
|
||||
return "", ""
|
||||
}
|
||||
if isImageResolution(modelType, size) || isVideoResolution(modelType, size) {
|
||||
return size, "size"
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func normalizedRequestResolution(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" || isEmptyParamString(value) {
|
||||
return ""
|
||||
}
|
||||
switch strings.ToLower(value) {
|
||||
case "auto", "automatic", "adaptive", "default":
|
||||
return ""
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func isResolutionFilteredModelType(modelType string) bool {
|
||||
return modelType == "image_generate" || modelType == "image_edit" || isVideoModelType(modelType)
|
||||
}
|
||||
|
||||
func candidateSupportsRequestResolution(candidate store.RuntimeModelCandidate, requirement requestResolutionRequirement) (bool, map[string]any) {
|
||||
modelType := firstNonEmptyString(candidate.ModelType, requirement.ModelType)
|
||||
capability := capabilityForType(effectiveModelCapability(candidate), modelType)
|
||||
detail := candidateResolutionDetail(candidate, requirement, modelType)
|
||||
if capability == nil {
|
||||
detail["reason"] = "capability_missing"
|
||||
detail["message"] = "候选平台模型未配置对应模型类型能力。"
|
||||
detail["capabilityPath"] = capabilityPath(modelType, "output_resolutions")
|
||||
return false, detail
|
||||
}
|
||||
|
||||
allowed, configured := outputResolutionAllowedValues(capability["output_resolutions"], requirement.Scopes)
|
||||
detail["allowedResolutions"] = allowed
|
||||
detail["capabilityPath"] = capabilityPath(modelType, "output_resolutions")
|
||||
detail["capabilityValue"] = cloneAny(capability["output_resolutions"])
|
||||
if !configured {
|
||||
detail["reason"] = "output_resolutions_missing"
|
||||
detail["message"] = "候选平台模型未声明 output_resolutions。"
|
||||
return false, detail
|
||||
}
|
||||
if containsResolution(allowed, requirement.Resolution) {
|
||||
detail["reason"] = "supported"
|
||||
return true, detail
|
||||
}
|
||||
detail["reason"] = "resolution_not_allowed"
|
||||
detail["message"] = "请求分辨率不在候选平台模型 output_resolutions 中。"
|
||||
return false, detail
|
||||
}
|
||||
|
||||
func outputResolutionAllowedValues(value any, scopes []string) ([]string, bool) {
|
||||
switch typed := value.(type) {
|
||||
case []any, []string, string:
|
||||
return uniqueStringList(stringListFromAny(typed)), true
|
||||
case map[string]any:
|
||||
for _, scope := range append(scopes, "default", "*", "all") {
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if raw, ok := typed[scope]; ok {
|
||||
return uniqueStringList(stringListFromAny(raw)), true
|
||||
}
|
||||
}
|
||||
if len(scopes) == 0 {
|
||||
values := make([]string, 0)
|
||||
for _, raw := range typed {
|
||||
values = append(values, stringListFromAny(raw)...)
|
||||
}
|
||||
return uniqueStringList(values), len(values) > 0
|
||||
}
|
||||
return nil, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func containsResolution(values []string, target string) bool {
|
||||
for _, value := range values {
|
||||
if strings.EqualFold(strings.TrimSpace(value), strings.TrimSpace(target)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func candidateResolutionDetail(candidate store.RuntimeModelCandidate, requirement requestResolutionRequirement, modelType string) map[string]any {
|
||||
return map[string]any{
|
||||
"platformId": candidate.PlatformID,
|
||||
"platformKey": candidate.PlatformKey,
|
||||
"platformName": candidate.PlatformName,
|
||||
"provider": candidate.Provider,
|
||||
"platformModelId": candidate.PlatformModelID,
|
||||
"modelName": candidate.ModelName,
|
||||
"modelAlias": candidate.ModelAlias,
|
||||
"displayName": candidate.DisplayName,
|
||||
"providerModelName": candidate.ProviderModelName,
|
||||
"modelType": modelType,
|
||||
"requested": map[string]any{
|
||||
"resolution": requirement.Resolution,
|
||||
"source": requirement.Source,
|
||||
"scopes": requirement.Scopes,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func requestResolutionFilterSummary(requirement requestResolutionRequirement, candidateCount int, supportedCandidateCount int, rejected []map[string]any, supportedResolutions []string) map[string]any {
|
||||
return map[string]any{
|
||||
"code": unsupportedRequestResolutionCode,
|
||||
"filter": "request_resolution",
|
||||
"kind": requirement.Kind,
|
||||
"requestedModel": requirement.RequestedModel,
|
||||
"modelType": requirement.ModelType,
|
||||
"requestedResolution": requirement.Resolution,
|
||||
"resolutionSource": requirement.Source,
|
||||
"resolutionScopes": requirement.Scopes,
|
||||
"capabilityPath": capabilityPath(requirement.ModelType, "output_resolutions"),
|
||||
"candidateCount": candidateCount,
|
||||
"supportedCandidateCount": supportedCandidateCount,
|
||||
"filteredCandidateCount": len(rejected),
|
||||
"supportedResolutions": uniqueStringList(supportedResolutions),
|
||||
"rejectedCandidates": rejected,
|
||||
}
|
||||
}
|
||||
|
||||
func unsupportedRequestResolutionMessage(requirement requestResolutionRequirement, rejected []map[string]any) string {
|
||||
resource := "媒体"
|
||||
if requirement.ModelType == "image_generate" || requirement.ModelType == "image_edit" {
|
||||
resource = "图像"
|
||||
} else if isVideoModelType(requirement.ModelType) {
|
||||
resource = "视频"
|
||||
}
|
||||
message := fmt.Sprintf("请求的%s分辨率 %s 没有可用平台模型支持,已过滤 %d 个候选平台模型", resource, requirement.Resolution, len(rejected))
|
||||
if summaries := rejectedResolutionSummaries(rejected, 3); len(summaries) > 0 {
|
||||
message += ";候选支持:" + strings.Join(summaries, ";")
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func rejectedResolutionSummaries(rejected []map[string]any, limit int) []string {
|
||||
summaries := make([]string, 0, limit)
|
||||
for _, item := range rejected {
|
||||
if len(summaries) >= limit {
|
||||
break
|
||||
}
|
||||
allowed := stringListFromAny(item["allowedResolutions"])
|
||||
if len(allowed) == 0 {
|
||||
continue
|
||||
}
|
||||
name := firstNonEmptyString(stringFromAny(item["platformName"]), stringFromAny(item["platformKey"]), stringFromAny(item["provider"]))
|
||||
model := firstNonEmptyString(stringFromAny(item["displayName"]), stringFromAny(item["modelAlias"]), stringFromAny(item["modelName"]))
|
||||
if model != "" {
|
||||
name = firstNonEmptyString(name, model)
|
||||
if name != model {
|
||||
name += "/" + model
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
name = "候选"
|
||||
}
|
||||
summaries = append(summaries, fmt.Sprintf("%s=%s", name, strings.Join(allowed, "/")))
|
||||
}
|
||||
return summaries
|
||||
}
|
||||
|
||||
func requestResolutionScopes(body map[string]any, modelType string) []string {
|
||||
if !isVideoModelType(modelType) {
|
||||
return nil
|
||||
}
|
||||
scopes := make([]string, 0)
|
||||
for _, key := range []string{"videoMode", "video_mode", "mode", "generation_mode", "generate_mode", "supported_mode"} {
|
||||
appendUniqueString(&scopes, stringFromMap(body, key))
|
||||
}
|
||||
stats := videoResolutionReferenceStatsFromBody(body)
|
||||
if stats.HasFirstFrame && stats.HasLastFrame {
|
||||
appendUniqueString(&scopes, "input_first_last_frame")
|
||||
appendUniqueString(&scopes, "first_last_frame")
|
||||
} else if stats.HasFirstFrame {
|
||||
appendUniqueString(&scopes, "input_first_frame")
|
||||
} else if stats.HasLastFrame {
|
||||
appendUniqueString(&scopes, "input_last_frame")
|
||||
}
|
||||
if stats.ReferenceImages > 1 {
|
||||
appendUniqueString(&scopes, "input_reference_generate_multiple")
|
||||
appendUniqueString(&scopes, "image_reference")
|
||||
} else if stats.ReferenceImages == 1 {
|
||||
appendUniqueString(&scopes, "input_reference_generate_single")
|
||||
appendUniqueString(&scopes, "image_reference")
|
||||
}
|
||||
if stats.HasReferenceVideo {
|
||||
appendUniqueString(&scopes, "video_reference")
|
||||
}
|
||||
if stats.HasReferenceAudio {
|
||||
appendUniqueString(&scopes, "audio_reference")
|
||||
}
|
||||
if !stats.HasAnyMedia {
|
||||
appendUniqueString(&scopes, "text_to_video")
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
func videoResolutionReferenceStatsFromBody(body map[string]any) videoResolutionReferenceStats {
|
||||
stats := videoResolutionReferenceStats{}
|
||||
content := contentItems(body["content"])
|
||||
stats.HasExplicitContent = len(content) > 0
|
||||
for _, item := range content {
|
||||
if isImageContent(item) {
|
||||
stats.HasAnyMedia = true
|
||||
switch strings.TrimSpace(stringFromAny(item["role"])) {
|
||||
case "first_frame":
|
||||
stats.HasFirstFrame = true
|
||||
case "last_frame":
|
||||
stats.HasLastFrame = true
|
||||
default:
|
||||
stats.ReferenceImages++
|
||||
}
|
||||
}
|
||||
if isVideoContent(item) {
|
||||
stats.HasAnyMedia = true
|
||||
stats.HasReferenceVideo = true
|
||||
}
|
||||
if isAudioContent(item) {
|
||||
stats.HasAnyMedia = true
|
||||
stats.HasReferenceAudio = true
|
||||
}
|
||||
}
|
||||
if hasAnyString(body, "first_frame", "firstFrame") {
|
||||
stats.HasAnyMedia = true
|
||||
stats.HasFirstFrame = true
|
||||
}
|
||||
if hasAnyString(body, "last_frame", "lastFrame") {
|
||||
stats.HasAnyMedia = true
|
||||
stats.HasLastFrame = true
|
||||
}
|
||||
if hasAnyString(body, "reference_image", "referenceImage") {
|
||||
stats.HasAnyMedia = true
|
||||
stats.ReferenceImages++
|
||||
}
|
||||
if hasAnyString(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") {
|
||||
stats.HasAnyMedia = true
|
||||
stats.HasReferenceVideo = true
|
||||
}
|
||||
if hasAnyString(body, "audio_url", "audioUrl", "reference_audio", "referenceAudio") {
|
||||
stats.HasAnyMedia = true
|
||||
stats.HasReferenceAudio = true
|
||||
}
|
||||
if hasAnyString(body, "image", "images", "image_url", "imageUrl", "image_urls", "imageUrls") {
|
||||
stats.HasAnyMedia = true
|
||||
if !stats.HasFirstFrame && !stats.HasExplicitContent {
|
||||
stats.HasFirstFrame = true
|
||||
} else {
|
||||
stats.ReferenceImages++
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
func candidateCapabilityFilterMetrics(summary map[string]any) map[string]any {
|
||||
if len(summary) == 0 {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{"candidateCapabilityFilter": summary}
|
||||
}
|
||||
191
apps/api/internal/runner/candidate_filter_test.go
Normal file
191
apps/api/internal/runner/candidate_filter_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestFilterRuntimeCandidatesByRequestResolutionKeepsSupportedCandidate(t *testing.T) {
|
||||
candidates := []store.RuntimeModelCandidate{
|
||||
candidateWithResolutions("low", "720p"),
|
||||
candidateWithResolutions("high", "1080p"),
|
||||
}
|
||||
|
||||
filtered, summary, err := filterRuntimeCandidatesByRequest("videos.generations", "demo-video", "video_generate", map[string]any{
|
||||
"resolution": "1080p",
|
||||
}, candidates)
|
||||
if err != nil {
|
||||
t.Fatalf("filter should keep a supported candidate: %v", err)
|
||||
}
|
||||
if len(filtered) != 1 || filtered[0].PlatformKey != "high" {
|
||||
t.Fatalf("expected only high resolution candidate, got %+v", filtered)
|
||||
}
|
||||
if summary["filteredCandidateCount"] != 1 || summary["supportedCandidateCount"] != 1 {
|
||||
t.Fatalf("unexpected filter summary: %+v", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRuntimeCandidatesByScopedVideoResolution(t *testing.T) {
|
||||
candidates := []store.RuntimeModelCandidate{
|
||||
{
|
||||
PlatformID: "platform-first",
|
||||
PlatformKey: "first",
|
||||
PlatformName: "First Frame Platform",
|
||||
PlatformModelID: "model-first",
|
||||
ModelName: "demo-video",
|
||||
ModelType: "image_to_video",
|
||||
Capabilities: map[string]any{
|
||||
"image_to_video": map[string]any{
|
||||
"output_resolutions": map[string]any{
|
||||
"input_first_frame": []any{"1080p"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PlatformID: "platform-first-last",
|
||||
PlatformKey: "first-last",
|
||||
PlatformName: "First Last Platform",
|
||||
PlatformModelID: "model-first-last",
|
||||
ModelName: "demo-video",
|
||||
ModelType: "image_to_video",
|
||||
Capabilities: map[string]any{
|
||||
"image_to_video": map[string]any{
|
||||
"output_resolutions": map[string]any{
|
||||
"input_first_last_frame": []any{"1080p"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filtered, _, err := filterRuntimeCandidatesByRequest("videos.generations", "demo-video", "image_to_video", map[string]any{
|
||||
"resolution": "1080p",
|
||||
"content": []any{
|
||||
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}},
|
||||
map[string]any{"type": "image_url", "role": "last_frame", "image_url": map[string]any{"url": "https://example.com/last.png"}},
|
||||
},
|
||||
}, candidates)
|
||||
if err != nil {
|
||||
t.Fatalf("filter should keep first-last scoped candidate: %v", err)
|
||||
}
|
||||
if len(filtered) != 1 || filtered[0].PlatformKey != "first-last" {
|
||||
t.Fatalf("expected first-last scoped candidate only, got %+v", filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRuntimeCandidatesByRequestResolutionFailsWithDetails(t *testing.T) {
|
||||
candidates := []store.RuntimeModelCandidate{
|
||||
candidateWithImageResolutions("jimeng-v3", "1K", "2K"),
|
||||
candidateWithImageResolutions("jimeng-v4", "1K"),
|
||||
}
|
||||
|
||||
filtered, summary, err := filterRuntimeCandidatesByRequest("images.generations", "demo-image", "image_generate", map[string]any{
|
||||
"resolution": "4K",
|
||||
}, candidates)
|
||||
if len(filtered) != 0 {
|
||||
t.Fatalf("expected no candidates, got %+v", filtered)
|
||||
}
|
||||
var candidateErr *store.ModelCandidateUnavailableError
|
||||
if !errors.As(err, &candidateErr) {
|
||||
t.Fatalf("expected model candidate error, got %T %v", err, err)
|
||||
}
|
||||
if candidateErr.Code != unsupportedRequestResolutionCode {
|
||||
t.Fatalf("unexpected error code: %s", candidateErr.Code)
|
||||
}
|
||||
if !strings.Contains(candidateErr.Message, "4K") {
|
||||
t.Fatalf("message should include requested resolution, got %q", candidateErr.Message)
|
||||
}
|
||||
if summary["filteredCandidateCount"] != 2 || candidateErr.Details["requestedResolution"] != "4K" {
|
||||
t.Fatalf("unexpected filter detail summary=%+v details=%+v", summary, candidateErr.Details)
|
||||
}
|
||||
if details := store.ModelCandidateErrorDetails(err); details["requestedResolution"] != "4K" {
|
||||
t.Fatalf("store detail helper should expose requested resolution, got %+v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRuntimeCandidatesSkipsPixelSizeCompatibility(t *testing.T) {
|
||||
candidates := []store.RuntimeModelCandidate{{
|
||||
PlatformID: "openai",
|
||||
PlatformKey: "openai",
|
||||
PlatformModelID: "gpt-image-1",
|
||||
ModelName: "gpt-image-1",
|
||||
ModelType: "image_generate",
|
||||
Capabilities: map[string]any{
|
||||
"image_generate": map[string]any{
|
||||
"aspect_ratio_allowed": []any{"1:1"},
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
filtered, summary, err := filterRuntimeCandidatesByRequest("images.generations", "gpt-image-1", "image_generate", map[string]any{
|
||||
"size": "1024x1024",
|
||||
}, candidates)
|
||||
if err != nil {
|
||||
t.Fatalf("pixel size compatibility should skip resolution filtering: %v", err)
|
||||
}
|
||||
if len(filtered) != 1 || summary != nil {
|
||||
t.Fatalf("expected unchanged candidates and no summary, got filtered=%+v summary=%+v", filtered, summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFailureResultIncludesModelCandidateDetails(t *testing.T) {
|
||||
cause := &store.ModelCandidateUnavailableError{
|
||||
Code: unsupportedRequestResolutionCode,
|
||||
Message: "unsupported resolution",
|
||||
Details: map[string]any{
|
||||
"requestedResolution": "4K",
|
||||
"candidateCount": 2,
|
||||
},
|
||||
}
|
||||
|
||||
result := buildFailureResult(store.ModelCandidateErrorCode(cause), cause.Error(), "", cause)
|
||||
errorPayload, _ := result["error"].(map[string]any)
|
||||
modelCandidate, _ := errorPayload["modelCandidate"].(map[string]any)
|
||||
if errorPayload["code"] != unsupportedRequestResolutionCode || modelCandidate["requestedResolution"] != "4K" {
|
||||
t.Fatalf("failure result should persist candidate details, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func candidateWithResolutions(platformKey string, resolutions ...string) store.RuntimeModelCandidate {
|
||||
return store.RuntimeModelCandidate{
|
||||
PlatformID: "platform-" + platformKey,
|
||||
PlatformKey: platformKey,
|
||||
PlatformName: "Platform " + platformKey,
|
||||
PlatformModelID: "model-" + platformKey,
|
||||
ModelName: "demo-video",
|
||||
ModelType: "video_generate",
|
||||
Capabilities: map[string]any{
|
||||
"video_generate": map[string]any{
|
||||
"output_resolutions": stringsToAny(resolutions),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func candidateWithImageResolutions(platformKey string, resolutions ...string) store.RuntimeModelCandidate {
|
||||
return store.RuntimeModelCandidate{
|
||||
PlatformID: "platform-" + platformKey,
|
||||
PlatformKey: platformKey,
|
||||
PlatformName: "Platform " + platformKey,
|
||||
PlatformModelID: "model-" + platformKey,
|
||||
ModelName: "demo-image",
|
||||
ModelType: "image_generate",
|
||||
Capabilities: map[string]any{
|
||||
"image_generate": map[string]any{
|
||||
"output_resolutions": stringsToAny(resolutions),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func stringsToAny(values []string) []any {
|
||||
out := make([]any, 0, len(values))
|
||||
for _, value := range values {
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -470,7 +470,7 @@ func isEmptyParamString(value string) bool {
|
||||
}
|
||||
|
||||
func isImageResolution(modelType string, value string) bool {
|
||||
return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "4K", "8K"}, value)
|
||||
return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "3K", "4K", "8K"}, value)
|
||||
}
|
||||
|
||||
func isVideoResolution(modelType string, value string) bool {
|
||||
|
||||
@ -19,7 +19,12 @@ type EstimateResult struct {
|
||||
|
||||
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
|
||||
body = normalizeRequest(kind, body)
|
||||
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user)
|
||||
modelType := modelTypeFromKind(kind, body)
|
||||
candidates, err := s.store.ListModelCandidates(ctx, model, modelType, user)
|
||||
if err != nil {
|
||||
return EstimateResult{}, err
|
||||
}
|
||||
candidates, _, err = filterRuntimeCandidatesByRequest(kind, model, modelType, body, candidates)
|
||||
if err != nil {
|
||||
return EstimateResult{}, err
|
||||
}
|
||||
|
||||
@ -214,6 +214,9 @@ func failureMetrics(err error, simulated bool) (string, map[string]any, time.Tim
|
||||
if detail := rateLimitFailureDetail(err); len(detail) > 0 {
|
||||
metrics["rateLimit"] = detail
|
||||
}
|
||||
if detail := store.ModelCandidateErrorDetails(err); len(detail) > 0 {
|
||||
metrics["modelCandidate"] = detail
|
||||
}
|
||||
}
|
||||
if meta.StatusCode > 0 {
|
||||
metrics["statusCode"] = meta.StatusCode
|
||||
@ -230,6 +233,23 @@ func failureMetrics(err error, simulated bool) (string, map[string]any, time.Tim
|
||||
return meta.RequestID, metrics, meta.ResponseStartedAt, meta.ResponseFinishedAt, meta.ResponseDurationMS
|
||||
}
|
||||
|
||||
func buildFailureResult(code string, message string, requestID string, err error) map[string]any {
|
||||
errorPayload := map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
if requestID != "" {
|
||||
errorPayload["requestId"] = requestID
|
||||
}
|
||||
if detail := rateLimitFailureDetail(err); len(detail) > 0 {
|
||||
errorPayload["rateLimit"] = detail
|
||||
}
|
||||
if detail := store.ModelCandidateErrorDetails(err); len(detail) > 0 {
|
||||
errorPayload["modelCandidate"] = detail
|
||||
}
|
||||
return map[string]any{"error": errorPayload}
|
||||
}
|
||||
|
||||
func rateLimitFailureDetail(err error) map[string]any {
|
||||
var limitErr *store.RateLimitExceededError
|
||||
if !errors.As(err, &limitErr) {
|
||||
|
||||
@ -120,6 +120,28 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
}
|
||||
return Result{Task: failed, Output: failed.Result}, err
|
||||
}
|
||||
var candidateFilterSummary map[string]any
|
||||
candidates, candidateFilterSummary, err = filterRuntimeCandidatesByRequest(task.Kind, task.Model, modelType, body, candidates)
|
||||
if err != nil {
|
||||
candidateFilterMetrics := candidateCapabilityFilterMetrics(candidateFilterSummary)
|
||||
s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||
Task: task,
|
||||
Body: body,
|
||||
AttemptNo: task.AttemptCount + 1,
|
||||
Code: store.ModelCandidateErrorCode(err),
|
||||
Cause: err,
|
||||
Simulated: task.RunMode == "simulation",
|
||||
Scope: "candidate_request_filter",
|
||||
Reason: store.ModelCandidateErrorCode(err),
|
||||
ExtraMetrics: []map[string]any{candidateFilterMetrics},
|
||||
ModelType: modelType,
|
||||
})
|
||||
failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), task.RunMode == "simulation", err, candidateFilterMetrics)
|
||||
if finishErr != nil {
|
||||
return Result{}, finishErr
|
||||
}
|
||||
return Result{Task: failed, Output: failed.Result}, err
|
||||
}
|
||||
firstCandidateBody := body
|
||||
normalizedModelType := modelType
|
||||
attemptNo := task.AttemptCount
|
||||
@ -230,6 +252,7 @@ candidatesLoop:
|
||||
attemptNo = nextAttemptNo
|
||||
billings := s.billings(ctx, user, task.Kind, candidateBody, candidate, response, isSimulation(task, candidate))
|
||||
record := buildSuccessRecord(task, user, candidateBody, candidate, response, billings, isSimulation(task, candidate))
|
||||
record.Metrics = mergeMetrics(record.Metrics, candidateCapabilityFilterMetrics(candidateFilterSummary))
|
||||
record.Metrics = mergeMetrics(record.Metrics, parameterPreprocessingMetrics(preprocessing.Log))
|
||||
record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics)
|
||||
finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
|
||||
@ -674,6 +697,7 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
|
||||
TaskID: taskID,
|
||||
Code: code,
|
||||
Message: message,
|
||||
Result: buildFailureResult(code, message, requestID, cause),
|
||||
RequestID: requestID,
|
||||
Metrics: metrics,
|
||||
ResponseStartedAt: responseStartedAt,
|
||||
|
||||
@ -14,6 +14,7 @@ var (
|
||||
type ModelCandidateUnavailableError struct {
|
||||
Code string
|
||||
Message string
|
||||
Details map[string]any
|
||||
}
|
||||
|
||||
func (e *ModelCandidateUnavailableError) Error() string {
|
||||
@ -32,6 +33,14 @@ func ModelCandidateErrorCode(err error) string {
|
||||
return "no_model_candidate"
|
||||
}
|
||||
|
||||
func ModelCandidateErrorDetails(err error) map[string]any {
|
||||
var candidateErr *ModelCandidateUnavailableError
|
||||
if errors.As(err, &candidateErr) && len(candidateErr.Details) > 0 {
|
||||
return candidateErr.Details
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RateLimitExceededError struct {
|
||||
ScopeType string
|
||||
ScopeKey string
|
||||
@ -247,6 +256,7 @@ type FinishTaskFailureInput struct {
|
||||
TaskID string
|
||||
Code string
|
||||
Message string
|
||||
Result map[string]any
|
||||
RequestID string
|
||||
Metrics map[string]any
|
||||
ResponseStartedAt time.Time
|
||||
|
||||
@ -778,22 +778,24 @@ func taskBillingString(value any) string {
|
||||
|
||||
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
||||
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
||||
resultJSON, _ := json.Marshal(emptyObjectIfNil(input.Result))
|
||||
if _, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'failed',
|
||||
error = NULLIF($2::text, ''),
|
||||
error_code = NULLIF($3::text, ''),
|
||||
error_message = NULLIF($2::text, ''),
|
||||
request_id = NULLIF($4::text, ''),
|
||||
metrics = $5::jsonb,
|
||||
response_started_at = $6::timestamptz,
|
||||
response_finished_at = $7::timestamptz,
|
||||
response_duration_ms = $8,
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
finished_at = now(),
|
||||
updated_at = now()
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'failed',
|
||||
error = NULLIF($2::text, ''),
|
||||
error_code = NULLIF($3::text, ''),
|
||||
error_message = NULLIF($2::text, ''),
|
||||
request_id = NULLIF($4::text, ''),
|
||||
metrics = $5::jsonb,
|
||||
response_started_at = $6::timestamptz,
|
||||
response_finished_at = $7::timestamptz,
|
||||
response_duration_ms = $8,
|
||||
result = $9::jsonb,
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
finished_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`,
|
||||
input.TaskID,
|
||||
input.Message,
|
||||
@ -803,6 +805,7 @@ WHERE id = $1::uuid`,
|
||||
nullableTime(input.ResponseStartedAt),
|
||||
nullableTime(input.ResponseFinishedAt),
|
||||
input.ResponseDurationMS,
|
||||
string(resultJSON),
|
||||
); err != nil {
|
||||
return GatewayTask{}, err
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user