Merge origin/main into chore/devenv-setup

This commit is contained in:
chensipeng 2026-05-15 09:40:56 +08:00
commit be283daaa3
50 changed files with 10909 additions and 2190 deletions

1
.gitignore vendored
View File

@ -8,6 +8,7 @@ node_modules/
apps/api/bin/
apps/api/tmp/
apps/api/data/
coverage/

View File

@ -329,6 +329,12 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
var gotModel string
var gotText string
var gotFirstFrameRole string
var gotDuration float64
var gotRatio string
var gotResolution string
var gotSeed float64
var gotCameraFixed bool
var gotWatermark bool
var submittedRemoteTaskID string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
@ -343,6 +349,17 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
if body["prompt"] != nil || body["first_frame"] != nil {
t.Fatalf("video convenience fields leaked upstream: %+v", body)
}
for _, key := range []string{"duration_seconds", "aspect_ratio", "audio", "cameraFixed"} {
if _, ok := body[key]; ok {
t.Fatalf("volces video task body should not include top-level %s: %+v", key, body)
}
}
gotDuration, _ = body["duration"].(float64)
gotRatio, _ = body["ratio"].(string)
gotResolution, _ = body["resolution"].(string)
gotSeed, _ = body["seed"].(float64)
gotCameraFixed, _ = body["camera_fixed"].(bool)
gotWatermark, _ = body["watermark"].(bool)
content, _ := body["content"].([]any)
textItem, _ := content[0].(map[string]any)
gotText, _ = textItem["text"].(string)
@ -375,6 +392,10 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
"first_frame": "https://example.com/first.png",
"duration": 6,
"aspect_ratio": "16:9",
"resolution": "720p",
"seed": 11,
"cameraFixed": false,
"watermark": true,
},
Candidate: store.RuntimeModelCandidate{
BaseURL: server.URL,
@ -406,10 +427,11 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" {
t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole)
}
for _, fragment := range []string{"A clean product reveal", "--dur 6", "--ratio 16:9", "--watermark false", "--seed -1"} {
if !strings.Contains(gotText, fragment) {
t.Fatalf("expected text to contain %q, got %q", fragment, gotText)
}
if gotText != "A clean product reveal" {
t.Fatalf("video params should not be appended to prompt text, got %q", gotText)
}
if gotDuration != 6 || gotRatio != "16:9" || gotResolution != "720p" || gotSeed != 11 || gotCameraFixed != false || gotWatermark != true {
t.Fatalf("unexpected submitted video params duration=%v ratio=%s resolution=%s seed=%v camera_fixed=%v watermark=%v", gotDuration, gotRatio, gotResolution, gotSeed, gotCameraFixed, gotWatermark)
}
data, _ := response.Result["data"].([]any)
item, _ := data[0].(map[string]any)
@ -418,6 +440,181 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
}
}
func TestVolcesClientVideoRejectsDuplicateFirstFrameBeforeSubmit(t *testing.T) {
var submitted bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
submitted = true
t.Fatalf("duplicate first_frame request should not be submitted upstream")
}))
defer server.Close()
_, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
Kind: "videos.generations",
ModelType: "image_to_video",
Model: "豆包Seedance",
Body: map[string]any{
"model": "豆包Seedance",
"content": []any{
map[string]any{"type": "text", "text": "animate it"},
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}},
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/second.png"}},
},
},
Candidate: store.RuntimeModelCandidate{
BaseURL: server.URL,
ProviderModelName: "doubao-seedance-1-5-pro-251215",
Credentials: map[string]any{"apiKey": "volces-key"},
},
})
if err == nil || ErrorCode(err) != "invalid_parameter" {
t.Fatalf("expected local invalid_parameter error, got %v", err)
}
if submitted {
t.Fatal("request was submitted upstream")
}
}
func TestVolcesVideoBodyAllowsOnlyTaskPayloadFields(t *testing.T) {
body := volcesVideoBody(Request{
Kind: "videos.generations",
ModelType: "omni_video",
Model: "豆包Seedance",
Body: map[string]any{
"model": "豆包Seedance",
"duration": 8,
"duration_seconds": 8,
"aspect_ratio": "9:16",
"resolution": "720p",
"audio": true,
"callback_url": "https://example.com/callback",
"returnLastFrame": true,
"executionExpiresAfter": 3600,
"draft": false,
"cameraFixed": false,
"watermark": true,
"seed": -1,
"task_id": "local-task-id",
"runMode": "simulation",
"fps": 24,
"content": []any{
map[string]any{"type": "text", "text": "Use <<<element_1>>> in a product reveal"},
map[string]any{
"type": "element",
"element": map[string]any{
"inline_element": map[string]any{
"name": "subject",
"frontal_image_url": "https://example.com/subject.png",
"refer_images": []any{map[string]any{"url": "https://example.com/side.png", "slot_key": "side"}},
},
},
},
map[string]any{
"type": "image_url",
"role": "unexpected_role",
"name": "drop-me",
"image_url": map[string]any{"url": "https://example.com/ref.png", "extra": "drop-me"},
},
map[string]any{
"type": "video_url",
"duration": 3,
"video_url": map[string]any{
"url": "https://example.com/ref.mp4",
"refer_type": "feature",
"keep_original_sound": "yes",
"extra": "drop-me",
},
},
map[string]any{
"type": "audio_url",
"audio_url": map[string]any{"url": "https://example.com/ref.mp3", "extra": "drop-me"},
},
},
},
Candidate: store.RuntimeModelCandidate{
ModelName: "豆包Seedance",
ProviderModelName: "doubao-seedance-2-0-260128",
Credentials: map[string]any{"apiKey": "volces-key"},
},
})
allowedTopLevel := map[string]bool{
"model": true, "content": true, "callback_url": true, "return_last_frame": true, "execution_expires_after": true,
"generate_audio": true, "draft": true, "resolution": true, "ratio": true, "duration": true,
"seed": true, "camera_fixed": true, "watermark": true,
}
for key := range body {
if !allowedTopLevel[key] {
t.Fatalf("unexpected top-level volces field %q in %+v", key, body)
}
}
if body["model"] != "doubao-seedance-2-0-260128" ||
body["generate_audio"] != true ||
body["callback_url"] != "https://example.com/callback" ||
body["return_last_frame"] != true ||
body["execution_expires_after"] != 3600 ||
body["draft"] != false ||
body["resolution"] != "720p" ||
body["ratio"] != "9:16" ||
body["duration"] != 8 ||
body["seed"] != -1 ||
body["camera_fixed"] != false ||
body["watermark"] != true {
t.Fatalf("unexpected direct video fields: %+v", body)
}
content, ok := body["content"].([]map[string]any)
if !ok || len(content) != 5 {
t.Fatalf("unexpected sanitized content: %#v", body["content"])
}
text := content[0]
if text["type"] != "text" || strings.Contains(text["text"].(string), "--dur") || strings.Contains(text["text"].(string), "--ratio") {
t.Fatalf("video params should not be appended to the text item: %+v", text)
}
elementImage := content[1]
if elementImage["type"] != "image_url" || elementImage["role"] != "reference_image" {
t.Fatalf("referenced element should be converted to reference image: %+v", elementImage)
}
imageURL, _ := elementImage["image_url"].(map[string]any)
if imageURL["url"] != "https://example.com/subject.png" || len(imageURL) != 1 {
t.Fatalf("element image payload should only include url: %+v", imageURL)
}
referenceImage := content[2]
if referenceImage["role"] != "reference_image" || referenceImage["name"] != nil {
t.Fatalf("image references should be role-normalized and scrubbed: %+v", referenceImage)
}
videoItem := content[3]
videoURL, _ := videoItem["video_url"].(map[string]any)
if videoItem["role"] != "reference_video" || videoURL["url"] != "https://example.com/ref.mp4" || videoURL["refer_type"] != "feature" || videoURL["extra"] != nil {
t.Fatalf("video references should keep only allowed nested fields: %+v", videoItem)
}
audioItem := content[4]
audioURL, _ := audioItem["audio_url"].(map[string]any)
if audioItem["role"] != "reference_audio" || audioURL["url"] != "https://example.com/ref.mp3" || len(audioURL) != 1 {
t.Fatalf("audio references should keep only url: %+v", audioItem)
}
}
func TestVolcesVideoBodyPrefersFramesOverDuration(t *testing.T) {
body := volcesVideoBody(Request{
Kind: "videos.generations",
ModelType: "video_generate",
Body: map[string]any{
"prompt": "A quick camera move",
"duration": 8,
"frames": 57,
},
Candidate: store.RuntimeModelCandidate{
ProviderModelName: "doubao-seedance-1-0-pro-250528",
},
})
if body["frames"] != 57 {
t.Fatalf("frames should be passed through as the official duration control: %+v", body)
}
if _, ok := body["duration"]; ok {
t.Fatalf("duration should not be sent when frames is present: %+v", body)
}
}
func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) {
var submitCalled bool
var pollPath string

View File

@ -339,5 +339,12 @@ func firstNonEmptyPrompt(body map[string]any, fallback string) string {
return value
}
}
for _, item := range contentItems(body["content"]) {
if stringValue(item, "type") == "text" {
if value := strings.TrimSpace(stringValue(item, "text")); value != "" {
return value
}
}
}
return fallback
}

View File

@ -7,10 +7,14 @@ import (
"fmt"
"math"
"net/http"
"regexp"
"strconv"
"strings"
"time"
)
var volcesElementReferencePattern = regexp.MustCompile(`(?i)<<<[[:space:]]*element[_-]?([0-9]+)[[:space:]]*>>>|@element([0-9]+)`)
type VolcesClient struct {
HTTPClient *http.Client
}
@ -72,6 +76,9 @@ func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey stri
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
if upstreamTaskID == "" {
body := volcesVideoBody(request)
if err := validateVolcesVideoTaskBody(body); err != nil {
return Response{}, err
}
submitResult, requestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
submitRequestID = requestID
if err != nil {
@ -215,11 +222,9 @@ func volcesVideoBody(request Request) map[string]any {
content = buildVolcesContentFromBody(body)
}
appendMultiShotTimeline(&content)
convertVolcesElementsToImageReferences(&content)
normalizeVolcesContentRoles(content)
appendVolcesVideoParams(&content, body)
body["content"] = content
stripVolcesVideoConvenienceFields(body)
return body
return volcesVideoTaskBody(body, content)
}
func cleanProviderBody(body map[string]any) map[string]any {
@ -286,56 +291,267 @@ func buildVolcesContentFromBody(body map[string]any) []map[string]any {
return content
}
func stripVolcesVideoConvenienceFields(body map[string]any) {
for _, key := range []string{
"prompt",
"input",
"image",
"images",
"image_url",
"imageUrl",
"image_urls",
"imageUrls",
"reference_image",
"referenceImage",
"first_frame",
"firstFrame",
"last_frame",
"lastFrame",
"video",
"video_url",
"videoUrl",
"reference_video",
"referenceVideo",
"audio_url",
"audioUrl",
"reference_audio",
"referenceAudio",
} {
delete(body, key)
func volcesVideoTaskBody(body map[string]any, content []map[string]any) map[string]any {
out := map[string]any{
"model": body["model"],
"content": sanitizeVolcesVideoContent(content),
}
addVolcesVideoTaskParams(out, body)
return out
}
func validateVolcesVideoTaskBody(body map[string]any) error {
firstFrameCount := 0
lastFrameCount := 0
for _, item := range contentItems(body["content"]) {
if stringFromAny(item["type"]) != "image_url" {
continue
}
switch stringFromAny(item["role"]) {
case "first_frame":
firstFrameCount++
case "last_frame":
lastFrameCount++
}
}
if firstFrameCount > 1 {
return &ClientError{
Code: "invalid_parameter",
Message: fmt.Sprintf("content contains %d first_frame image items; expected at most one first frame image content", firstFrameCount),
StatusCode: 400,
Retryable: false,
}
}
if lastFrameCount > 1 {
return &ClientError{
Code: "invalid_parameter",
Message: fmt.Sprintf("content contains %d last_frame image items; expected at most one last frame image content", lastFrameCount),
StatusCode: 400,
Retryable: false,
}
}
return nil
}
func addVolcesVideoTaskParams(out map[string]any, body map[string]any) {
copyVolcesStringParam(out, "callback_url", body, "callback_url", "callbackUrl")
copyVolcesBoolParam(out, "return_last_frame", body, "return_last_frame", "returnLastFrame")
copyVolcesIntParam(out, "execution_expires_after", body, "execution_expires_after", "executionExpiresAfter")
copyVolcesBoolParam(out, "generate_audio", body, "generate_audio", "generateAudio", "audio")
copyVolcesBoolParam(out, "draft", body, "draft")
copyVolcesStringParam(out, "resolution", body, "resolution", "size")
copyVolcesStringParam(out, "ratio", body, "ratio", "aspect_ratio", "aspectRatio")
if copyVolcesIntParam(out, "frames", body, "frames") {
delete(out, "duration")
} else {
copyVolcesIntParam(out, "duration", body, "duration", "duration_seconds", "durationSeconds", "dur")
}
copyVolcesIntParam(out, "seed", body, "seed")
copyVolcesBoolParam(out, "camera_fixed", body, "camera_fixed", "cameraFixed", "camerafixed", "cf")
copyVolcesBoolParam(out, "watermark", body, "watermark")
}
func copyVolcesStringParam(out map[string]any, target string, body map[string]any, keys ...string) bool {
for _, key := range keys {
if value := strings.TrimSpace(stringFromAny(body[key])); value != "" {
out[target] = value
return true
}
}
return false
}
func copyVolcesIntParam(out map[string]any, target string, body map[string]any, keys ...string) bool {
for _, key := range keys {
if value, ok := volcesIntFromAny(body[key]); ok {
out[target] = value
return true
}
}
return false
}
func copyVolcesBoolParam(out map[string]any, target string, body map[string]any, keys ...string) bool {
for _, key := range keys {
if value, ok := volcesBoolFromAny(body[key]); ok {
out[target] = value
return true
}
}
return false
}
func volcesIntFromAny(value any) (int, bool) {
switch typed := value.(type) {
case nil:
return 0, false
case int:
return typed, true
case int64:
return int(typed), true
case float64:
return int(math.Round(typed)), true
case string:
text := strings.TrimSpace(typed)
if text == "" {
return 0, false
}
if parsed, err := strconv.ParseFloat(text, 64); err == nil {
return int(math.Round(parsed)), true
}
return 0, false
default:
return 0, false
}
}
func contentItems(value any) []map[string]any {
rawItems, ok := value.([]any)
if !ok {
return nil
func volcesBoolFromAny(value any) (bool, bool) {
switch typed := value.(type) {
case nil:
return false, false
case bool:
return typed, true
case int:
if typed == 1 {
return true, true
}
if typed == 0 {
return false, true
}
case int64:
if typed == 1 {
return true, true
}
if typed == 0 {
return false, true
}
case float64:
if typed == 1 {
return true, true
}
if typed == 0 {
return false, true
}
case string:
normalized := strings.ToLower(strings.TrimSpace(typed))
if normalized == "true" || normalized == "1" {
return true, true
}
if normalized == "false" || normalized == "0" {
return false, true
}
}
out := make([]map[string]any, 0, len(rawItems))
for _, raw := range rawItems {
item, ok := raw.(map[string]any)
if !ok {
continue
return false, false
}
func sanitizeVolcesVideoContent(content []map[string]any) []map[string]any {
out := make([]map[string]any, 0, len(content))
for _, item := range content {
switch stringFromAny(item["type"]) {
case "text":
out = append(out, map[string]any{
"type": "text",
"text": strings.TrimSpace(stringFromAny(item["text"])),
})
case "image_url":
url := volcesNestedURL(item, "image_url")
if url == "" {
continue
}
out = append(out, map[string]any{
"type": "image_url",
"role": volcesImageRole(item),
"image_url": map[string]any{"url": url},
})
case "video_url":
url := volcesNestedURL(item, "video_url")
if url == "" {
continue
}
videoURL := map[string]any{"url": url}
if value := strings.TrimSpace(stringFromAny(mapFromAny(item["video_url"])["refer_type"])); value != "" {
videoURL["refer_type"] = value
}
if value := strings.TrimSpace(stringFromAny(mapFromAny(item["video_url"])["keep_original_sound"])); value != "" {
videoURL["keep_original_sound"] = value
}
out = append(out, map[string]any{
"type": "video_url",
"role": "reference_video",
"video_url": videoURL,
})
case "audio_url":
url := volcesNestedURL(item, "audio_url")
if url == "" {
continue
}
out = append(out, map[string]any{
"type": "audio_url",
"role": "reference_audio",
"audio_url": map[string]any{"url": url},
})
}
copied := map[string]any{}
for key, value := range item {
copied[key] = value
}
out = append(out, copied)
}
if len(out) == 0 {
return []map[string]any{{"type": "text", "text": ""}}
}
return out
}
func volcesImageRole(item map[string]any) string {
switch strings.TrimSpace(stringFromAny(item["role"])) {
case "first_frame":
return "first_frame"
case "last_frame":
return "last_frame"
default:
return "reference_image"
}
}
func volcesNestedURL(item map[string]any, key string) string {
nested := mapFromAny(item[key])
return strings.TrimSpace(stringFromAny(nested["url"]))
}
func mapFromAny(value any) map[string]any {
if object, ok := value.(map[string]any); ok {
return object
}
return nil
}
func contentItems(value any) []map[string]any {
switch typed := value.(type) {
case []any:
out := make([]map[string]any, 0, len(typed))
for _, raw := range typed {
item, ok := raw.(map[string]any)
if !ok {
continue
}
copied := map[string]any{}
for key, value := range item {
copied[key] = value
}
out = append(out, copied)
}
return out
case []map[string]any:
out := make([]map[string]any, 0, len(typed))
for _, item := range typed {
copied := map[string]any{}
for key, value := range item {
copied[key] = value
}
out = append(out, copied)
}
return out
default:
return nil
}
}
func normalizeVolcesContentRoles(content []map[string]any) {
for _, item := range content {
itemType := strings.TrimSpace(stringFromAny(item["type"]))
@ -353,32 +569,115 @@ func normalizeVolcesContentRoles(content []map[string]any) {
}
}
func appendVolcesVideoParams(content *[]map[string]any, body map[string]any) {
textItem := ensureTextContent(content)
current := strings.TrimSpace(stringFromAny(textItem["text"]))
values := []struct {
key string
value any
}{
{"dur", firstPresent(body["duration"], body["dur"])},
{"ratio", firstPresent(body["aspect_ratio"], body["aspectRatio"], body["ratio"])},
{"fps", firstPresent(body["framespersecond"], body["framesPerSecond"], body["fps"])},
{"watermark", firstPresent(body["watermark"], false)},
{"seed", firstPresent(body["seed"], -1)},
{"cf", firstPresent(body["camerafixed"], body["cameraFixed"])},
{"rs", firstPresent(body["resolution"], body["size"])},
}
for _, item := range values {
valueText := volcesParamString(item.value)
if valueText == "" || strings.Contains(current, "--"+item.key) {
func convertVolcesElementsToImageReferences(content *[]map[string]any) {
referenced := referencedVolcesElementIndexes(*content)
out := make([]map[string]any, 0, len(*content))
elementIndex := 0
for _, item := range *content {
if stringFromAny(item["type"]) != "element" {
out = append(out, item)
continue
}
if current != "" {
current += " "
elementIndex++
if !referenced[elementIndex] {
continue
}
current += "--" + item.key + " " + valueText
url := volcesElementFrontalImageURL(item)
if url == "" {
continue
}
role := stringFromAny(item["role"])
if role != "first_frame" && role != "last_frame" {
role = "reference_image"
}
out = append(out, map[string]any{
"type": "image_url",
"role": role,
"image_url": map[string]any{"url": url},
})
}
*content = out
}
func referencedVolcesElementIndexes(content []map[string]any) map[int]bool {
out := map[int]bool{}
for _, item := range content {
if stringFromAny(item["type"]) != "text" {
continue
}
text := stringFromAny(item["text"])
if strings.TrimSpace(text) == "" {
continue
}
for _, match := range volcesElementReferencePattern.FindAllStringSubmatch(text, -1) {
raw := ""
if len(match) > 1 && match[1] != "" {
raw = match[1]
} else if len(match) > 2 {
raw = match[2]
}
index, err := strconv.Atoi(raw)
if err == nil && index > 0 {
out[index] = true
}
}
}
return out
}
func volcesElementFrontalImageURL(item map[string]any) string {
element := mapFromAny(item["element"])
if element == nil {
return ""
}
inline := mapFromAny(element["inline_element"])
for _, value := range []any{
inline["frontal_image_url"],
element["frontal_image_url"],
element["front_image_url"],
element["image_url"],
} {
if url := strings.TrimSpace(stringFromAny(value)); url != "" {
return url
}
}
return volcesReferImageURL(firstPresent(inline["refer_images"], element["refer_images"]))
}
func volcesReferImageURL(value any) string {
images := mapListFromAny(value)
firstURL := ""
for _, image := range images {
url := strings.TrimSpace(stringFromAny(image["url"]))
if url == "" {
continue
}
if firstURL == "" {
firstURL = url
}
slot := strings.ToLower(strings.TrimSpace(stringFromAny(image["slot_key"])))
if slot == "frontal" || slot == "front" {
return url
}
}
return firstURL
}
func mapListFromAny(value any) []map[string]any {
switch typed := value.(type) {
case []any:
out := make([]map[string]any, 0, len(typed))
for _, item := range typed {
if object := mapFromAny(item); object != nil {
out = append(out, object)
}
}
return out
case []map[string]any:
return typed
default:
return nil
}
textItem["text"] = current
}
func appendMultiShotTimeline(content *[]map[string]any) {
@ -625,31 +924,6 @@ func firstNonEmptyStringListFromAny(values ...any) []string {
return nil
}
func volcesParamString(value any) string {
switch typed := value.(type) {
case nil:
return ""
case string:
return strings.TrimSpace(typed)
case bool:
if typed {
return "true"
}
return "false"
case int:
return fmt.Sprintf("%d", typed)
case int64:
return fmt.Sprintf("%d", typed)
case float64:
if math.Mod(typed, 1) == 0 {
return fmt.Sprintf("%d", int64(typed))
}
return fmt.Sprintf("%g", typed)
default:
return fmt.Sprintf("%v", typed)
}
}
func numericValue(value any, fallback float64) float64 {
switch typed := value.(type) {
case int:

View File

@ -7,6 +7,11 @@ import (
"strings"
)
const (
DefaultLocalGeneratedStorageDir = "data/static/generated"
DefaultLocalUploadedStorageDir = "data/static/uploaded"
)
type Config struct {
AppEnv string
HTTPAddr string
@ -15,6 +20,9 @@ type Config struct {
JWTSecret string
ServerMainBaseURL string
ServerMainInternalToken string
PublicBaseURL string
LocalGeneratedStorageDir string
LocalUploadedStorageDir string
TaskProgressCallbackEnabled bool
TaskProgressCallbackURL string
TaskProgressCallbackTimeoutMS string
@ -38,6 +46,9 @@ func Load() Config {
"/",
),
ServerMainInternalToken: env("SERVER_MAIN_INTERNAL_TOKEN", ""),
PublicBaseURL: strings.TrimRight(env("AI_GATEWAY_PUBLIC_BASE_URL", env("PUBLIC_BASE_URL", "")), "/"),
LocalGeneratedStorageDir: env("AI_GATEWAY_GENERATED_STORAGE_DIR", env("LOCAL_GENERATED_STORAGE_DIR", env("AI_GATEWAY_STATIC_STORAGE_DIR", DefaultLocalGeneratedStorageDir))),
LocalUploadedStorageDir: env("AI_GATEWAY_UPLOADED_STORAGE_DIR", env("LOCAL_UPLOADED_STORAGE_DIR", DefaultLocalUploadedStorageDir)),
TaskProgressCallbackEnabled: env("TASK_PROGRESS_CALLBACK_ENABLED", "true") == "true",
TaskProgressCallbackURL: env("TASK_PROGRESS_CALLBACK_URL",
strings.TrimRight(env("SERVER_MAIN_BASE_URL", "http://localhost:3000"), "/")+"/internal/platform/task-progress-callbacks",

View File

@ -0,0 +1,58 @@
package httpapi
import (
"io"
"net/http"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
)
const maxGatewayUploadBytes = 256 << 20
func (s *Server) uploadFile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
if err := r.ParseMultipartForm(32 << 20); err != nil {
writeError(w, http.StatusBadRequest, "invalid multipart upload")
return
}
file, header, err := r.FormFile("file")
if err != nil {
writeError(w, http.StatusBadRequest, "file is required")
return
}
defer file.Close()
payload, err := io.ReadAll(file)
if err != nil {
writeError(w, http.StatusBadRequest, "read upload file failed")
return
}
contentType := strings.TrimSpace(header.Header.Get("Content-Type"))
if contentType == "" && len(payload) > 0 {
contentType = http.DetectContentType(payload)
}
upload, err := s.runner.UploadFile(r.Context(), runner.FileUploadPayload{
Bytes: payload,
ContentType: contentType,
FileName: header.Filename,
Source: firstNonEmptyFormValue(r, "source", "ai-gateway-openapi"),
})
if err != nil {
s.logger.Error("upload file failed", "error", err)
status := http.StatusBadGateway
if clients.ErrorCode(err) == "upload_no_channel" {
status = http.StatusServiceUnavailable
}
writeError(w, status, err.Error())
return
}
writeJSON(w, http.StatusOK, upload)
}
func firstNonEmptyFormValue(r *http.Request, key string, fallback string) string {
if value := strings.TrimSpace(r.FormValue(key)); value != "" {
return value
}
return fallback
}

View File

@ -597,7 +597,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
status := statusFromRunError(runErr)
errorPayload := map[string]any{
"code": runErrorCode(runErr),
"message": runErr.Error(),
"message": runErrorMessage(runErr),
"status": status,
}
if result.Task.ID != "" {
@ -606,6 +606,9 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
if result.Task.RequestID != "" {
errorPayload["requestId"] = result.Task.RequestID
}
for key, value := range runErrorDetails(runErr) {
errorPayload[key] = value
}
sendSSE(w, "error", map[string]any{"error": errorPayload})
if flusher != nil {
flusher.Flush()
@ -626,7 +629,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
if !requestStillConnected(r) {
return
}
writeError(w, statusFromRunError(runErr), runErr.Error(), runErrorCode(runErr))
writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr))
return
}
if !requestStillConnected(r) {
@ -742,6 +745,138 @@ func runErrorCode(err error) string {
return clients.ErrorCode(err)
}
func runErrorMessage(err error) string {
if err == nil {
return ""
}
if summary := rateLimitErrorSummary(err); summary != "" {
return err.Error() + "" + summary
}
return err.Error()
}
func runErrorDetails(err error) map[string]any {
if detail := rateLimitErrorDetail(err); len(detail) > 0 {
return map[string]any{"rateLimit": detail}
}
return nil
}
func rateLimitErrorSummary(err error) string {
var limitErr *store.RateLimitExceededError
if !errors.As(err, &limitErr) {
return ""
}
scopeLabel := "限流对象"
switch limitErr.ScopeType {
case "user_group":
scopeLabel = "用户组"
case "platform_model":
scopeLabel = "平台模型"
}
scopeName := strings.TrimSpace(limitErr.ScopeName)
if scopeName == "" {
scopeName = strings.TrimSpace(limitErr.ScopeKey)
}
if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); limitErr.ScopeType == "user_group" && groupKey != "" && groupKey != scopeName {
scopeName = fmt.Sprintf("%s(%s)", scopeName, groupKey)
}
projected := limitErr.Projected
if projected <= 0 {
projected = limitErr.Current + limitErr.Amount
}
parts := []string{
fmt.Sprintf("限流摘要:%s %s 的 %s 超限", scopeLabel, scopeName, limitErr.Metric),
fmt.Sprintf("当前 %s本次 %s预计 %s限制 %s", formatRateLimitValue(limitErr.Current), formatRateLimitValue(limitErr.Amount), formatRateLimitValue(projected), formatRateLimitValue(limitErr.Limit)),
}
if limitErr.WindowSeconds > 0 {
parts = append(parts, fmt.Sprintf("窗口 %d 秒", limitErr.WindowSeconds))
}
if limitErr.RetryAfter > 0 {
parts = append(parts, fmt.Sprintf("约%s后可重试", formatRateLimitDuration(limitErr.RetryAfter)))
} else if !limitErr.Retryable {
parts = append(parts, "该请求超过单次限额,不能排队重试")
}
return strings.Join(parts, "")
}
func rateLimitErrorDetail(err error) map[string]any {
var limitErr *store.RateLimitExceededError
if !errors.As(err, &limitErr) {
return nil
}
detail := map[string]any{
"scopeType": limitErr.ScopeType,
"scopeKey": limitErr.ScopeKey,
"scopeName": limitErr.ScopeName,
"metric": limitErr.Metric,
"limit": limitErr.Limit,
"amount": limitErr.Amount,
"current": limitErr.Current,
"used": limitErr.Used,
"reserved": limitErr.Reserved,
"projected": limitErr.Projected,
"windowSeconds": limitErr.WindowSeconds,
"retryable": limitErr.Retryable,
"exceeded": map[string]any{
"metric": limitErr.Metric,
"current": limitErr.Current,
"amount": limitErr.Amount,
"projected": limitErr.Projected,
"limit": limitErr.Limit,
},
}
if limitErr.RetryAfter > 0 {
detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds()
}
if !limitErr.ResetAt.IsZero() {
detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano)
}
if len(limitErr.Policy) > 0 {
detail["rateLimitPolicy"] = limitErr.Policy
if matchedRule := matchedRateLimitRule(limitErr.Policy, limitErr.Metric); len(matchedRule) > 0 {
detail["matchedRule"] = matchedRule
}
}
if len(limitErr.ScopeMetadata) > 0 {
detail["scopeMetadata"] = limitErr.ScopeMetadata
}
if limitErr.ScopeType == "user_group" {
userGroup := map[string]any{
"id": limitErr.ScopeKey,
"name": limitErr.ScopeName,
}
if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); groupKey != "" {
userGroup["groupKey"] = groupKey
}
detail["userGroup"] = userGroup
}
return detail
}
func formatRateLimitValue(value float64) string {
return strconv.FormatFloat(value, 'f', -1, 64)
}
func formatRateLimitDuration(duration time.Duration) string {
if duration < time.Second {
return strconv.FormatInt(duration.Milliseconds(), 10) + "毫秒"
}
seconds := duration.Seconds()
return strconv.FormatFloat(seconds, 'f', -1, 64) + "秒"
}
func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
rules, _ := policy["rules"].([]any)
for _, rawRule := range rules {
rule, _ := rawRule.(map[string]any)
if stringValue(rule["metric"]) == metric {
return rule
}
}
return nil
}
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
if !ok {

View File

@ -0,0 +1,72 @@
package httpapi
import (
"strings"
"testing"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func TestRateLimitErrorDetailIncludesUserGroupAndExceededMetric(t *testing.T) {
resetAt := time.Date(2026, 5, 15, 10, 30, 0, 0, time.UTC)
detail := rateLimitErrorDetail(&store.RateLimitExceededError{
ScopeType: "user_group",
ScopeKey: "group-1",
ScopeName: "VIP 用户组",
ScopeMetadata: map[string]any{"groupKey": "vip"},
Metric: "rpm",
Limit: 2,
Amount: 1,
Current: 2,
Used: 1,
Reserved: 1,
Projected: 3,
WindowSeconds: 60,
ResetAt: resetAt,
RetryAfter: 5 * time.Second,
Retryable: true,
Policy: map[string]any{
"rules": []any{
map[string]any{"metric": "rpm", "limit": float64(2), "windowSeconds": float64(60)},
},
},
})
if detail["metric"] != "rpm" || detail["projected"] != float64(3) || detail["limit"] != float64(2) {
t.Fatalf("unexpected exceeded detail: %+v", detail)
}
userGroup, _ := detail["userGroup"].(map[string]any)
if userGroup["id"] != "group-1" || userGroup["groupKey"] != "vip" || userGroup["name"] != "VIP 用户组" {
t.Fatalf("missing user group detail: %+v", detail)
}
matchedRule, _ := detail["matchedRule"].(map[string]any)
if matchedRule["metric"] != "rpm" {
t.Fatalf("missing matched rule: %+v", detail)
}
if detail["retryAfterMs"] != int64(5000) || detail["resetAt"] != resetAt.Format(time.RFC3339Nano) {
t.Fatalf("missing retry/reset detail: %+v", detail)
}
}
func TestRunErrorMessageIncludesRateLimitSummary(t *testing.T) {
message := runErrorMessage(&store.RateLimitExceededError{
ScopeType: "user_group",
ScopeKey: "group-1",
ScopeName: "VIP 用户组",
ScopeMetadata: map[string]any{"groupKey": "vip"},
Metric: "rpm",
Limit: 2,
Amount: 1,
Current: 2,
Projected: 3,
WindowSeconds: 60,
RetryAfter: 5 * time.Second,
Retryable: true,
Message: "rate limit exceeded: rpm window has no remaining capacity",
})
for _, expected := range []string{"限流摘要", "用户组 VIP 用户组(vip)", "rpm 超限", "当前 2", "本次 1", "预计 3", "限制 2", "窗口 60 秒", "约5秒后可重试"} {
if !strings.Contains(message, expected) {
t.Fatalf("message %q should contain %q", message, expected)
}
}
}

View File

@ -14,6 +14,10 @@ func writeJSON(w http.ResponseWriter, status int, value any) {
}
func writeError(w http.ResponseWriter, status int, message string, codes ...string) {
writeErrorWithDetails(w, status, message, nil, codes...)
}
func writeErrorWithDetails(w http.ResponseWriter, status int, message string, details map[string]any, codes ...string) {
errorPayload := map[string]any{
"message": message,
"status": status,
@ -23,6 +27,9 @@ func writeError(w http.ResponseWriter, status int, message string, codes ...stri
errorPayload["code"] = code
}
}
for key, value := range details {
errorPayload[key] = value
}
writeJSON(w, status, map[string]any{"error": errorPayload})
}

View File

@ -41,6 +41,8 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.HandleFunc("GET /healthz", server.health)
mux.HandleFunc("GET /readyz", server.ready)
mux.HandleFunc("GET /static/simulation/{asset}", serveSimulationAsset)
mux.HandleFunc("GET /static/generated/{asset}", server.serveGeneratedStaticAsset)
mux.HandleFunc("GET /static/uploaded/{asset}", server.serveUploadedStaticAsset)
mux.Handle("POST /api/v1/auth/register", server.auth.Require(auth.PermissionPublic, http.HandlerFunc(server.register)))
mux.Handle("POST /api/v1/auth/login", server.auth.Require(auth.PermissionPublic, http.HandlerFunc(server.login)))
@ -102,6 +104,12 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("GET /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getRunnerPolicy)))
mux.Handle("PATCH /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateRunnerPolicy)))
mux.Handle("GET /api/admin/config/network-proxy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getNetworkProxyConfig)))
mux.Handle("GET /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getFileStorageSettings)))
mux.Handle("PATCH /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageSettings)))
mux.Handle("GET /api/admin/system/file-storage/channels", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listFileStorageChannels)))
mux.Handle("POST /api/admin/system/file-storage/channels", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.createFileStorageChannel)))
mux.Handle("PATCH /api/admin/system/file-storage/channels/{channelID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageChannel)))
mux.Handle("DELETE /api/admin/system/file-storage/channels/{channelID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.deleteFileStorageChannel)))
mux.Handle("GET /api/admin/platforms", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPlatforms)))
mux.Handle("POST /api/admin/platforms", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.createPlatform)))
mux.Handle("PATCH /api/admin/platforms/{platformID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updatePlatform)))
@ -123,6 +131,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false)))
mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false)))
mux.Handle("POST /api/v1/videos/generations", server.auth.Require(auth.PermissionBasic, server.createTask("videos.generations", false)))
mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
@ -135,6 +144,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("POST /v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", true)))
mux.Handle("POST /images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", true)))
mux.Handle("POST /v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", true)))
mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
return server.recover(server.cors(mux))
}

View File

@ -0,0 +1,37 @@
package httpapi
import (
"net/http"
"os"
"path/filepath"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
)
func (s *Server) serveGeneratedStaticAsset(w http.ResponseWriter, r *http.Request) {
s.serveLocalStaticAsset(w, r, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir)
}
func (s *Server) serveUploadedStaticAsset(w http.ResponseWriter, r *http.Request) {
s.serveLocalStaticAsset(w, r, s.cfg.LocalUploadedStorageDir, config.DefaultLocalUploadedStorageDir)
}
func (s *Server) serveLocalStaticAsset(w http.ResponseWriter, r *http.Request, storageDir string, fallbackStorageDir string) {
fileName := filepath.Base(strings.TrimSpace(r.PathValue("asset")))
if fileName == "" || fileName == "." || fileName == ".." || fileName == string(filepath.Separator) {
http.NotFound(w, r)
return
}
storageDir = strings.TrimSpace(storageDir)
if storageDir == "" {
storageDir = fallbackStorageDir
}
filePath := filepath.Join(storageDir, fileName)
info, err := os.Stat(filePath)
if err != nil || info.IsDir() {
http.NotFound(w, r)
return
}
http.ServeFile(w, r, filePath)
}

View File

@ -0,0 +1,65 @@
package httpapi
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
)
func TestServeGeneratedStaticAsset(t *testing.T) {
storageDir := t.TempDir()
if err := os.WriteFile(filepath.Join(storageDir, "result.png"), []byte("png"), 0o644); err != nil {
t.Fatalf("failed to write generated asset fixture: %v", err)
}
server := &Server{cfg: config.Config{LocalGeneratedStorageDir: storageDir}}
request := httptest.NewRequest(http.MethodGet, "/static/generated/result.png", nil)
request.SetPathValue("asset", "result.png")
response := httptest.NewRecorder()
server.serveGeneratedStaticAsset(response, request)
if response.Code != http.StatusOK {
t.Fatalf("expected generated asset to be served, got status %d", response.Code)
}
if response.Body.String() != "png" {
t.Fatalf("unexpected generated asset payload: %q", response.Body.String())
}
}
func TestServeUploadedStaticAsset(t *testing.T) {
storageDir := t.TempDir()
if err := os.WriteFile(filepath.Join(storageDir, "upload.pdf"), []byte("pdf"), 0o644); err != nil {
t.Fatalf("failed to write uploaded asset fixture: %v", err)
}
server := &Server{cfg: config.Config{LocalUploadedStorageDir: storageDir}}
request := httptest.NewRequest(http.MethodGet, "/static/uploaded/upload.pdf", nil)
request.SetPathValue("asset", "upload.pdf")
response := httptest.NewRecorder()
server.serveUploadedStaticAsset(response, request)
if response.Code != http.StatusOK {
t.Fatalf("expected uploaded asset to be served, got status %d", response.Code)
}
if response.Body.String() != "pdf" {
t.Fatalf("unexpected uploaded asset payload: %q", response.Body.String())
}
}
func TestServeLocalStaticAssetRejectsTraversal(t *testing.T) {
storageDir := t.TempDir()
server := &Server{cfg: config.Config{LocalGeneratedStorageDir: storageDir}}
request := httptest.NewRequest(http.MethodGet, "/static/generated/..", nil)
request.SetPathValue("asset", "..")
response := httptest.NewRecorder()
server.serveGeneratedStaticAsset(response, request)
if response.Code != http.StatusNotFound {
t.Fatalf("expected traversal-like generated asset name to 404, got status %d", response.Code)
}
}

View File

@ -0,0 +1,150 @@
package httpapi
import (
"encoding/json"
"net/http"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request) {
items, err := s.store.ListFileStorageChannels(r.Context())
if err != nil {
s.logger.Error("list file storage channels failed", "error", err)
writeError(w, http.StatusInternalServerError, "list file storage channels failed")
return
}
writeJSON(w, http.StatusOK, map[string]any{"items": items})
}
func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request) {
settings, err := s.store.GetFileStorageSettings(r.Context())
if err != nil {
if store.IsUndefinedDatabaseObject(err) {
writeJSON(w, http.StatusOK, store.DefaultFileStorageSettings())
return
}
s.logger.Error("get file storage settings failed", "error", err)
writeError(w, http.StatusInternalServerError, "get file storage settings failed")
return
}
writeJSON(w, http.StatusOK, settings)
}
func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Request) {
var input store.FileStorageSettingsInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
writeError(w, http.StatusBadRequest, "invalid json body")
return
}
settings, err := s.store.UpdateFileStorageSettings(r.Context(), input)
if err != nil {
s.logger.Error("update file storage settings failed", "error", err)
writeError(w, http.StatusInternalServerError, "update file storage settings failed")
return
}
writeJSON(w, http.StatusOK, settings)
}
func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request) {
var input store.FileStorageChannelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
writeError(w, http.StatusBadRequest, "invalid json body")
return
}
if message := validateFileStorageChannelInput(input, nil); message != "" {
writeError(w, http.StatusBadRequest, message)
return
}
item, err := s.store.CreateFileStorageChannel(r.Context(), input)
if err != nil {
if store.IsUniqueViolation(err) {
writeError(w, http.StatusConflict, "file storage channel key already exists")
return
}
s.logger.Error("create file storage channel failed", "error", err)
writeError(w, http.StatusInternalServerError, "create file storage channel failed")
return
}
writeJSON(w, http.StatusCreated, item)
}
func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request) {
var input store.FileStorageChannelInput
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
writeError(w, http.StatusBadRequest, "invalid json body")
return
}
existing, err := s.store.GetFileStorageChannel(r.Context(), r.PathValue("channelID"))
if err != nil {
if store.IsNotFound(err) {
writeError(w, http.StatusNotFound, "file storage channel not found")
return
}
s.logger.Error("get file storage channel failed", "error", err)
writeError(w, http.StatusInternalServerError, "get file storage channel failed")
return
}
if message := validateFileStorageChannelInput(input, &existing); message != "" {
writeError(w, http.StatusBadRequest, message)
return
}
item, err := s.store.UpdateFileStorageChannel(r.Context(), r.PathValue("channelID"), input)
if err != nil {
if store.IsNotFound(err) {
writeError(w, http.StatusNotFound, "file storage channel not found")
return
}
if store.IsUniqueViolation(err) {
writeError(w, http.StatusConflict, "file storage channel key already exists")
return
}
s.logger.Error("update file storage channel failed", "error", err)
writeError(w, http.StatusInternalServerError, "update file storage channel failed")
return
}
writeJSON(w, http.StatusOK, item)
}
func (s *Server) deleteFileStorageChannel(w http.ResponseWriter, r *http.Request) {
if err := s.store.DeleteFileStorageChannel(r.Context(), r.PathValue("channelID")); err != nil {
if store.IsNotFound(err) {
writeError(w, http.StatusNotFound, "file storage channel not found")
return
}
s.logger.Error("delete file storage channel failed", "error", err)
writeError(w, http.StatusInternalServerError, "delete file storage channel failed")
return
}
w.WriteHeader(http.StatusNoContent)
}
func validateFileStorageChannelInput(input store.FileStorageChannelInput, existing *store.FileStorageChannel) string {
provider := strings.ToLower(strings.TrimSpace(input.Provider))
if provider == "" {
provider = "server_main_openapi"
}
status := strings.ToLower(strings.TrimSpace(input.Status))
if status == "" {
status = "disabled"
}
if strings.TrimSpace(input.ChannelKey) == "" || strings.TrimSpace(input.Name) == "" {
return "channelKey and name are required"
}
if status != "enabled" && status != "disabled" {
return "status must be enabled or disabled"
}
if provider == "server_main_openapi" {
hasAPIKey := false
if input.APIKey != nil {
hasAPIKey = strings.TrimSpace(*input.APIKey) != ""
} else if existing != nil {
hasAPIKey = strings.TrimSpace(existing.APIKey) != ""
}
if status == "enabled" && !hasAPIKey {
return "server-main OpenAPI channel requires API key before enabling"
}
}
return ""
}

View File

@ -52,9 +52,31 @@ func isLocalRateLimitError(err error) bool {
func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation {
out := make([]store.RateLimitReservation, 0)
out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...)
out = append(out, reservationsFromPolicy(
"platform_model",
candidate.PlatformModelID,
firstNonEmptyString(candidate.DisplayName, candidate.ModelAlias, candidate.ModelName),
map[string]any{
"platformId": candidate.PlatformID,
"platformName": candidate.PlatformName,
"modelAlias": candidate.ModelAlias,
"modelName": candidate.ModelName,
},
effectiveRateLimitPolicy(candidate),
body,
)...)
if group, err := s.store.ResolveUserGroupPolicy(ctx, user); err == nil && group.ID != "" {
out = append(out, reservationsFromPolicy("user_group", group.ID, group.RateLimitPolicy, body)...)
out = append(out, reservationsFromPolicy(
"user_group",
group.ID,
firstNonEmptyString(group.Name, group.GroupKey),
map[string]any{
"groupKey": group.GroupKey,
"name": group.Name,
},
group.RateLimitPolicy,
body,
)...)
}
return out
}
@ -90,7 +112,7 @@ func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any
return policy
}
func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation {
func reservationsFromPolicy(scopeType string, scopeKey string, scopeName string, scopeMetadata map[string]any, policy map[string]any, body map[string]any) []store.RateLimitReservation {
if scopeKey == "" || !hasRules(policy) {
return nil
}
@ -108,11 +130,14 @@ func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string
out = append(out, store.RateLimitReservation{
ScopeType: scopeType,
ScopeKey: scopeKey,
ScopeName: scopeName,
ScopeMetadata: scopeMetadata,
Metric: metric,
Limit: limit,
Amount: amount,
WindowSeconds: int(floatFromAny(rule["windowSeconds"])),
LeaseTTLSeconds: int(floatFromAny(rule["leaseTtlSeconds"])),
Policy: policy,
})
}
return out
@ -131,6 +156,11 @@ func estimateRequestTokens(body map[string]any) int {
if input := stringFromMap(body, "input"); input != "" {
text += input
}
for _, item := range contentItems(body["content"]) {
if stringFromAny(item["type"]) == "text" {
text += stringFromAny(item["text"])
}
}
if messages, ok := body["messages"].([]any); ok {
for _, raw := range messages {
message, _ := raw.(map[string]any)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,380 @@
package runner
import (
"fmt"
"math"
"strings"
)
type resolutionNormalizeProcessor struct{}
func (resolutionNormalizeProcessor) Name() string { return "ResolutionNormalizeProcessor" }
func (resolutionNormalizeProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
if stringFromAny(params["resolution"]) != "" {
return false
}
size := stringFromAny(params["size"])
if size == "" {
return false
}
return isImageResolution(modelType, size) || isVideoResolution(modelType, size)
}
func (resolutionNormalizeProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
size := stringFromAny(params["size"])
if stringFromAny(params["resolution"]) == "" && (isImageResolution(modelType, size) || isVideoResolution(modelType, size)) {
_, capabilityValue := capabilityEvidence(context.modelCapability, modelType, "output_resolutions")
params["resolution"] = size
context.resolution = size
context.recordChange(
"ResolutionNormalizeProcessor",
"set",
"resolution",
nil,
size,
"size 使用分辨率格式,归一到 resolution 供后续能力校验和计费使用。",
capabilityPath(modelType, "output_resolutions"),
capabilityValue,
)
}
return true
}
type aspectRatioProcessor struct{}
func (aspectRatioProcessor) Name() string { return "AspectRatioProcessor" }
func (aspectRatioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
return modelType != "text_generate" && (stringFromAny(params["aspect_ratio"]) != "" || stringFromAny(params["size"]) != "")
}
func (aspectRatioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
capability := capabilityForType(context.modelCapability, modelType)
if capability == nil {
return true
}
aspectRatio := stringFromAny(params["aspect_ratio"])
if isEmptyParamString(aspectRatio) {
before := params["aspect_ratio"]
delete(params, "aspect_ratio")
context.aspectRatio = ""
context.recordChange(
"AspectRatioProcessor",
"remove",
"aspect_ratio",
before,
nil,
"aspect_ratio 是空值字符串,不能作为有效比例传给上游。",
"",
nil,
)
return true
}
resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution)
if resolution == "" {
if values := stringListFromAny(capability["output_resolutions"]); len(values) > 0 {
resolution = values[0]
} else if size := stringFromAny(params["size"]); strings.HasSuffix(size, "K") || strings.HasSuffix(size, "p") {
resolution = size
}
}
allowed := aspectRatioAllowed(capability["aspect_ratio_allowed"], resolution)
if allowed != nil && len(allowed) == 1 && allowed[0] == "adaptive" {
before := params["aspect_ratio"]
params["aspect_ratio"] = "adaptive"
context.aspectRatio = "adaptive"
if before != "adaptive" {
context.recordChange(
"AspectRatioProcessor",
"adjust",
"aspect_ratio",
before,
"adaptive",
"模型当前分辨率只允许 adaptive 宽高比。",
capabilityPath(modelType, "aspect_ratio_allowed"),
capability["aspect_ratio_allowed"],
)
}
return true
}
if allowed != nil && len(allowed) == 0 {
before := params["aspect_ratio"]
delete(params, "aspect_ratio")
context.aspectRatio = ""
context.recordChange(
"AspectRatioProcessor",
"remove",
"aspect_ratio",
before,
nil,
"模型能力配置不允许传入任何 aspect_ratio。",
capabilityPath(modelType, "aspect_ratio_allowed"),
capability["aspect_ratio_allowed"],
)
return true
}
if aspectRatio == "" {
return true
}
if allowed == nil && validAspectRatio(aspectRatio) {
params["aspect_ratio"] = aspectRatio
context.aspectRatio = aspectRatio
return true
}
processed, ok := validateAndAdjustAspectRatio(aspectRatio, capability, allowed)
if !ok {
before := params["aspect_ratio"]
delete(params, "aspect_ratio")
context.aspectRatio = ""
context.recordChange(
"AspectRatioProcessor",
"remove",
"aspect_ratio",
before,
nil,
"传入的 aspect_ratio 不在模型允许范围内,且没有可用替代值。",
capabilityPath(modelType, "aspect_ratio_allowed"),
capability["aspect_ratio_allowed"],
)
return true
}
if processed != "" {
before := params["aspect_ratio"]
params["aspect_ratio"] = processed
context.aspectRatio = processed
if before != processed {
path := capabilityPath(modelType, "aspect_ratio_allowed")
value := capability["aspect_ratio_allowed"]
if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok {
ratio, valid := aspectRatioNumber(aspectRatio)
if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] {
path = capabilityPath(modelType, "aspect_ratio_range")
value = capability["aspect_ratio_range"]
}
}
context.recordChange(
"AspectRatioProcessor",
"adjust",
"aspect_ratio",
before,
processed,
"传入的 aspect_ratio 不符合模型能力配置,已调整为允许值。",
path,
value,
)
}
}
return true
}
type inputAudioProcessor struct{}
func (inputAudioProcessor) Name() string { return "InputAudioProcessor" }
func (inputAudioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
if !isVideoModelType(modelType) {
return false
}
content := contentItems(params["content"])
for _, item := range content {
if isAudioContent(item) {
return true
}
}
return false
}
func (inputAudioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
content := contentItems(params["content"])
if len(content) == 0 {
return true
}
supportsInputAudio := false
if len(context.modelCapability) > 0 {
if isOmniVideoLike(context) {
supportsInputAudio = supportsOmniAudioReference(context)
} else if capability := capabilityForType(context.modelCapability, modelType); capability != nil {
supportsInputAudio = boolFromAny(capability["input_audio"])
}
}
if supportsInputAudio {
return true
}
next := make([]map[string]any, 0, len(content))
for index, item := range content {
if isAudioContent(item) {
path, value := audioInputCapabilityEvidence(context, modelType)
context.recordChange(
"InputAudioProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
"模型能力未开启输入音频,已移除 audio_url。",
path,
value,
)
continue
}
next = append(next, item)
}
params["content"] = mapsToAnySlice(next)
path, value := audioInputCapabilityEvidence(context, modelType)
deleteFieldsWithLog(params, context, "InputAudioProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "模型能力未开启输入音频,已移除音频参考快捷字段。", path, value)
return true
}
type durationProcessor struct{}
func (durationProcessor) Name() string { return "DurationProcessor" }
func (durationProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
return isVideoModelType(modelType) && params["duration"] != nil
}
func (durationProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
capability := capabilityForType(context.modelCapability, modelType)
if capability == nil {
return true
}
duration := floatFromAny(params["duration"])
if duration <= 0 {
return true
}
resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution)
modeKey := videoModeKey(params)
if options := scopedNumberList(capability["duration_options"], resolution, modeKey); len(options) > 0 {
normalized := nextAllowedNumber(duration, options)
params["duration"] = normalized
syncDurationSeconds(params)
if normalized != duration {
context.recordChange(
"DurationProcessor",
"adjust",
"duration",
duration,
normalized,
"duration 不在模型固定时长选项内,已向上调整为允许值。",
capabilityPath(modelType, "duration_options"),
capability["duration_options"],
)
}
return true
}
if minValue, maxValue, ok := scopedRange(capability["duration_range"], resolution, modeKey); ok {
step := durationStep(capability["duration_step"], resolution, modeKey)
normalized := normalizeDurationByRange(duration, minValue, maxValue, step)
params["duration"] = normalized
syncDurationSeconds(params)
if normalized != duration {
context.recordChange(
"DurationProcessor",
"adjust",
"duration",
duration,
normalized,
"duration 超出模型时长范围或步进配置,已按能力配置归一。",
capabilityPath(modelType, "duration_range"),
map[string]any{
"duration_range": capability["duration_range"],
"duration_step": capability["duration_step"],
},
)
}
return true
}
step := durationStep(capability["duration_step"], resolution, modeKey)
normalized := normalizeDurationByStep(duration, step)
params["duration"] = normalized
syncDurationSeconds(params)
if normalized != duration {
context.recordChange(
"DurationProcessor",
"adjust",
"duration",
duration,
normalized,
"duration 不符合模型时长步进,已按步进向上归一。",
capabilityPath(modelType, "duration_step"),
capability["duration_step"],
)
}
return true
}
type audioProcessor struct{}
func (audioProcessor) Name() string { return "AudioProcessor" }
func (audioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
return isVideoModelType(modelType) && (params["audio"] != nil || params["output_audio"] != nil)
}
func (audioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
capability := capabilityForType(context.modelCapability, modelType)
if capability == nil || !boolFromAny(capability["output_audio"]) {
for _, key := range []string{"audio", "output_audio"} {
if before, ok := params[key]; ok {
delete(params, key)
context.recordChange(
"AudioProcessor",
"remove",
key,
before,
nil,
"模型能力未开启输出音频,已移除音频输出参数。",
capabilityPath(modelType, "output_audio"),
capabilityValue(context.modelCapability, modelType, "output_audio"),
)
}
}
}
return true
}
type imageCountProcessor struct{}
func (imageCountProcessor) Name() string { return "ImageCountProcessor" }
func (imageCountProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
return modelType == "image_generate" || modelType == "image_edit"
}
func (imageCountProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
capability := capabilityForType(context.modelCapability, modelType)
if capability == nil || !boolFromAny(capability["output_multiple_images"]) {
return true
}
maxCount := int(math.Round(floatFromAny(capability["output_max_images_count"])))
if maxCount <= 0 {
return true
}
count := int(math.Round(floatFromAny(params["n"])))
if count <= 0 {
count = int(math.Round(floatFromAny(params["batch_size"])))
}
if count <= 0 {
count = 1
}
if count > maxCount {
before := count
count = maxCount
context.recordChange(
"ImageCountProcessor",
"adjust",
"n",
before,
count,
"请求图片数量超过模型输出上限,已按 output_max_images_count 截断。",
capabilityPath(modelType, "output_max_images_count"),
capability["output_max_images_count"],
)
}
params["n"] = count
return true
}

View File

@ -0,0 +1,190 @@
package runner
import "fmt"
type messageContentProcessor struct{}
func (messageContentProcessor) Name() string { return "MessageContentProcessor" }
func (messageContentProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
return isTextGenerationKind(context.kind) && params["messages"] != nil
}
func (messageContentProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
messages, changed := processMessageListContent(params["messages"], context)
if changed {
params["messages"] = messages
}
return true
}
func processMessageListContent(value any, context *paramProcessContext) ([]any, bool) {
rawMessages, ok := value.([]any)
if !ok {
return nil, false
}
out := make([]any, 0, len(rawMessages))
changed := false
for messageIndex, rawMessage := range rawMessages {
message, ok := rawMessage.(map[string]any)
if !ok {
out = append(out, rawMessage)
continue
}
nextMessage := cloneMap(message)
if contentParts, ok := message["content"].([]any); ok {
nextContent, contentChanged := processMessageContentParts(
contentParts,
fmt.Sprintf("messages[%d].content", messageIndex),
context,
)
if contentChanged {
nextMessage["content"] = nextContent
changed = true
}
}
out = append(out, nextMessage)
}
return out, changed
}
func processMessageContentParts(parts []any, basePath string, context *paramProcessContext) ([]any, bool) {
out := make([]any, 0, len(parts))
changed := false
for partIndex, rawPart := range parts {
part, ok := rawPart.(map[string]any)
if !ok {
out = append(out, rawPart)
continue
}
if replacement, replacementChanged := messageContentPartReplacement(part, context); replacementChanged {
out = append(out, replacement)
context.recordChange(
"MessageContentProcessor",
"convert",
fmt.Sprintf("%s[%d]", basePath, partIndex),
part,
replacement,
messageContentConversionReason(part),
messageContentCapabilityPath(part),
messageContentCapabilityValue(part, context),
)
changed = true
continue
}
out = append(out, cloneMap(part))
}
return out, changed
}
func messageContentPartReplacement(part map[string]any, context *paramProcessContext) (map[string]any, bool) {
switch {
case isImageContent(part):
if modelSupportsMessageModality(context, "image_analysis") {
return nil, false
}
if url := imageURLFromContentPart(part); url != "" {
return map[string]any{"type": "text", "text": "Image link: " + url}, true
}
case isVideoContent(part):
if modelSupportsMessageModality(context, "video_understanding") {
return nil, false
}
if url := videoURLFromContentPart(part); url != "" {
return map[string]any{"type": "text", "text": "video URL: " + url}, true
}
case isAudioContent(part) || stringFromAny(part["type"]) == "input_audio":
if modelSupportsMessageModality(context, "audio_understanding") {
return nil, false
}
if url := audioURLFromContentPart(part); url != "" {
return map[string]any{"type": "text", "text": "audio URL: " + url}, true
}
}
return nil, false
}
func messageContentConversionReason(part map[string]any) string {
switch {
case isImageContent(part):
return "模型不支持图像理解,已将 image_url 转为文本链接。"
case isVideoContent(part):
return "模型不支持视频理解,已将 video_url 转为文本链接。"
default:
return "模型不支持音频理解,已将音频输入转为文本链接。"
}
}
func messageContentCapabilityPath(part map[string]any) string {
switch {
case isImageContent(part):
return "capabilities.image_analysis"
case isVideoContent(part):
return "capabilities.video_understanding"
default:
return "capabilities.audio_understanding"
}
}
func messageContentCapabilityValue(part map[string]any, context *paramProcessContext) any {
if context == nil {
return nil
}
switch {
case isImageContent(part):
return capabilityValue(context.modelCapability, "image_analysis", "")
case isVideoContent(part):
return capabilityValue(context.modelCapability, "video_understanding", "")
default:
return capabilityValue(context.modelCapability, "audio_understanding", "")
}
}
func modelSupportsMessageModality(context *paramProcessContext, capabilityName string) bool {
if context == nil {
return false
}
capabilities := context.modelCapability
if capabilityForType(capabilities, capabilityName) != nil {
return true
}
if capabilityForType(capabilities, "omni") != nil {
return true
}
originalTypes := stringListFromAny(capabilities["originalTypes"])
return containsString(originalTypes, capabilityName) || containsString(originalTypes, "omni")
}
func imageURLFromContentPart(part map[string]any) string {
return urlFromNestedContentPart(part, "image_url", "url", "imageUrl")
}
func videoURLFromContentPart(part map[string]any) string {
return urlFromNestedContentPart(part, "video_url", "url", "videoUrl")
}
func audioURLFromContentPart(part map[string]any) string {
if stringFromAny(part["type"]) == "input_audio" {
if audio, ok := part["input_audio"].(map[string]any); ok {
if url := firstNonEmptyString(stringFromAny(audio["data"]), stringFromAny(audio["url"])); url != "" {
return url
}
}
}
return urlFromNestedContentPart(part, "audio_url", "url", "audioUrl")
}
func urlFromNestedContentPart(part map[string]any, keys ...string) string {
for _, key := range keys {
value := part[key]
if url := stringFromAny(value); url != "" {
return url
}
if nested, ok := value.(map[string]any); ok {
if url := stringFromAny(nested["url"]); url != "" {
return url
}
}
}
return ""
}

View File

@ -6,6 +6,50 @@ import (
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func TestVideoModelTypeInferenceReadsContentArray(t *testing.T) {
imageToVideo := modelTypeFromKind("videos.generations", map[string]any{
"model": "demo-video",
"content": []any{
map[string]any{"type": "text", "text": "animate it"},
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/frame.png"}},
},
})
if imageToVideo != "image_to_video" {
t.Fatalf("image content should infer image_to_video, got %s", imageToVideo)
}
omniVideo := modelTypeFromKind("videos.generations", map[string]any{
"model": "demo-video",
"content": []any{
map[string]any{"type": "text", "text": "edit it"},
map[string]any{"type": "video_url", "role": "reference_video", "video_url": map[string]any{"url": "https://example.com/ref.mp4"}},
},
})
if omniVideo != "omni_video" {
t.Fatalf("video content should infer omni_video, got %s", omniVideo)
}
textToVideo := modelTypeFromKind("videos.generations", map[string]any{
"model": "demo-video",
"content": []any{map[string]any{"type": "text", "text": "make a clip"}},
})
if textToVideo != "video_generate" {
t.Fatalf("text-only content should infer video_generate, got %s", textToVideo)
}
}
func TestVideoContentTextContributesToTokenEstimate(t *testing.T) {
tokens := estimateRequestTokens(map[string]any{
"model": "demo-video",
"content": []any{
map[string]any{"type": "text", "text": "a cinematic product reveal"},
},
})
if tokens <= 1 {
t.Fatalf("content text should contribute to token estimate, got %d", tokens)
}
}
func TestParamProcessorOmniFiltersUnsupportedVideoAndAudioContent(t *testing.T) {
body := map[string]any{
"model": "可灵O1",
@ -123,6 +167,163 @@ func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) {
t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes)
}
func TestParamProcessorChatConvertsUnsupportedMediaMessageContentToText(t *testing.T) {
body := map[string]any{
"model": "text-only",
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe these"},
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}},
map[string]any{"type": "input_audio", "input_audio": map[string]any{"data": "https://example.com/input.wav"}},
},
},
},
}
candidate := store.RuntimeModelCandidate{
ModelType: "text_generate",
Capabilities: map[string]any{
"text_generate": map[string]any{},
"originalTypes": []any{"text_generate"},
},
}
result := preprocessRequestWithLog("chat.completions", body, candidate)
messages, _ := result.Body["messages"].([]any)
if len(messages) != 1 {
t.Fatalf("expected one message, got %+v", result.Body["messages"])
}
message, _ := messages[0].(map[string]any)
content, _ := message["content"].([]any)
if len(content) != 5 {
t.Fatalf("expected five content parts, got %+v", message["content"])
}
expectedText := []string{
"describe these",
"Image link: https://example.com/image.png",
"video URL: https://example.com/video.mp4",
"audio URL: https://example.com/audio.mp3",
"audio URL: https://example.com/input.wav",
}
for index, expected := range expectedText {
part, _ := content[index].(map[string]any)
if stringFromAny(part["text"]) != expected {
t.Fatalf("content[%d] text = %q, want %q; all=%+v", index, stringFromAny(part["text"]), expected, content)
}
}
if len(result.Log.Changes) != 4 {
t.Fatalf("expected four media conversion changes, got %+v", result.Log.Changes)
}
expectedCapabilityPaths := map[string]bool{
"capabilities.image_analysis": false,
"capabilities.video_understanding": false,
"capabilities.audio_understanding": false,
}
for _, change := range result.Log.Changes {
if _, ok := expectedCapabilityPaths[change.CapabilityPath]; ok {
expectedCapabilityPaths[change.CapabilityPath] = true
}
}
for path, found := range expectedCapabilityPaths {
if !found {
t.Fatalf("expected conversion log for %s, got %+v", path, result.Log.Changes)
}
}
}
func TestParamProcessorChatKeepsOmniMessageContent(t *testing.T) {
body := map[string]any{
"model": "omni",
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}},
},
},
},
}
candidate := store.RuntimeModelCandidate{
ModelType: "text_generate",
Capabilities: map[string]any{
"text_generate": map[string]any{},
"omni": map[string]any{},
"originalTypes": []any{"text_generate", "omni"},
},
}
result := preprocessRequestWithLog("chat.completions", body, candidate)
if result.Log.Changed {
t.Fatalf("omni model should keep message media content unchanged, got %+v", result.Log.Changes)
}
messages, _ := result.Body["messages"].([]any)
message, _ := messages[0].(map[string]any)
content, _ := message["content"].([]any)
for _, item := range content {
part, _ := item.(map[string]any)
if stringFromAny(part["type"]) == "text" {
t.Fatalf("media content should not be converted for omni model: %+v", content)
}
}
}
func TestParamProcessorChatConvertsOnlyUnsupportedModalities(t *testing.T) {
body := map[string]any{
"model": "vision-only",
"messages": []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
},
},
},
}
candidate := store.RuntimeModelCandidate{
ModelType: "text_generate",
Capabilities: map[string]any{
"text_generate": map[string]any{},
"image_analysis": map[string]any{},
"originalTypes": []any{"text_generate", "image_analysis"},
},
}
result := preprocessRequestWithLog("chat.completions", body, candidate)
messages, _ := result.Body["messages"].([]any)
message, _ := messages[0].(map[string]any)
content, _ := message["content"].([]any)
first, _ := content[0].(map[string]any)
second, _ := content[1].(map[string]any)
if stringFromAny(first["type"]) != "image_url" {
t.Fatalf("image content should be kept when image_analysis is supported: %+v", content)
}
if stringFromAny(second["text"]) != "video URL: https://example.com/video.mp4" {
t.Fatalf("video content should be converted, got %+v", second)
}
if len(result.Log.Changes) != 1 || result.Log.Changes[0].CapabilityPath != "capabilities.video_understanding" {
t.Fatalf("expected only video conversion to be logged, got %+v", result.Log.Changes)
}
}
func TestSkipTaskParameterPreprocessingLogForTextModelTypes(t *testing.T) {
for _, modelType := range []string{"text_generate", "chat", "responses", "text"} {
if !skipTaskParameterPreprocessingLog(modelType) {
t.Fatalf("%s should skip task parameter preprocessing log", modelType)
}
}
for _, modelType := range []string{"image_generate", "image_edit", "video_generate", "omni_video"} {
if skipTaskParameterPreprocessingLog(modelType) {
t.Fatalf("%s should keep task parameter preprocessing log", modelType)
}
}
}
func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) {
body := map[string]any{
"model": "Seedance",
@ -180,6 +381,222 @@ func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) {
}
}
func TestParamProcessorDowngradesReferenceImagesToFrames(t *testing.T) {
candidate := store.RuntimeModelCandidate{
ModelType: "image_to_video",
Capabilities: map[string]any{
"image_to_video": map[string]any{
"input_first_frame": true,
"input_first_last_frame": true,
"input_reference_generate_single": false,
"input_reference_generate_multiple": false,
},
},
}
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"content": []any{
map[string]any{"type": "text", "text": "animate it"},
map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}},
map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/last.png"}},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Err != nil {
t.Fatalf("two image references should downgrade to first/last frames: %v", result.Err)
}
content := contentItems(result.Body["content"])
if stringFromAny(content[1]["role"]) != "first_frame" || stringFromAny(content[2]["role"]) != "last_frame" {
t.Fatalf("expected first/last frame downgrade, got %+v", content)
}
}
func TestParamProcessorDowngradesSingleReferenceImageToFirstFrame(t *testing.T) {
candidate := store.RuntimeModelCandidate{
ModelType: "image_to_video",
Capabilities: map[string]any{
"image_to_video": map[string]any{
"input_first_frame": true,
"input_first_last_frame": true,
"input_reference_generate_single": false,
"input_reference_generate_multiple": false,
},
},
}
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"content": []any{
map[string]any{"type": "text", "text": "animate it"},
map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Err != nil {
t.Fatalf("single image reference should downgrade to first frame: %v", result.Err)
}
content := contentItems(result.Body["content"])
if stringFromAny(content[1]["role"]) != "first_frame" {
t.Fatalf("expected first frame downgrade, got %+v", content)
}
}
func TestParamProcessorRejectsUnsafeReferenceImageDowngrade(t *testing.T) {
candidate := store.RuntimeModelCandidate{
ModelType: "image_to_video",
Capabilities: map[string]any{
"image_to_video": map[string]any{
"input_first_frame": true,
"input_first_last_frame": false,
"input_reference_generate_single": false,
"input_reference_generate_multiple": false,
},
},
}
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"content": []any{
map[string]any{"type": "text", "text": "animate it"},
map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}},
map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/last.png"}},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Err == nil {
t.Fatalf("two image references should be rejected when first/last frame is unsupported")
}
if len(result.Log.Changes) == 0 || result.Log.Changes[len(result.Log.Changes)-1].Action != "reject" {
t.Fatalf("expected reject preprocessing log, got %+v", result.Log.Changes)
}
}
func TestParamProcessorRejectsVideoOrAudioReferenceDowngrade(t *testing.T) {
candidate := store.RuntimeModelCandidate{
ModelType: "image_to_video",
Capabilities: map[string]any{
"image_to_video": map[string]any{
"input_first_frame": true,
"input_first_last_frame": true,
"input_reference_generate_single": false,
"input_reference_generate_multiple": false,
},
},
}
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"content": []any{
map[string]any{"type": "text", "text": "animate it"},
map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}},
map[string]any{"type": "video_url", "role": "reference_video", "video_url": map[string]any{"url": "https://example.com/ref.mp4"}},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Err == nil {
t.Fatalf("video reference should be rejected instead of downgraded")
}
}
func TestParamProcessorDurationRangeRoundsFractionalSecondsUp(t *testing.T) {
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"duration": 5.5,
}
candidate := store.RuntimeModelCandidate{
ModelType: "video_generate",
Capabilities: map[string]any{
"video_generate": map[string]any{
"duration_range": []any{3, 12},
},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Body["duration"] != float64(6) && result.Body["duration"] != 6 {
t.Fatalf("fractional duration should be rounded up to default 1s step, got %+v", result.Body["duration"])
}
}
func TestParamProcessorDurationWithoutRangeStillRoundsUp(t *testing.T) {
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"duration": 5.2,
}
candidate := store.RuntimeModelCandidate{
ModelType: "video_generate",
Capabilities: map[string]any{
"video_generate": map[string]any{},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Body["duration"] != float64(6) && result.Body["duration"] != 6 {
t.Fatalf("duration should default to a 1s upward step without range, got %+v", result.Body["duration"])
}
}
func TestParamProcessorDurationRangeUsesStepCeilingAndRange(t *testing.T) {
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"duration": 6.1,
"duration_seconds": 6.1,
}
candidate := store.RuntimeModelCandidate{
ModelType: "image_to_video",
Capabilities: map[string]any{
"image_to_video": map[string]any{
"duration_range": []any{5, 10},
"duration_step": 2,
},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Body["duration"] != float64(7) && result.Body["duration"] != 7 {
t.Fatalf("duration should be rounded up by configured step, got %+v", result.Body["duration"])
}
if result.Body["duration_seconds"] != result.Body["duration"] {
t.Fatalf("duration_seconds should sync with normalized duration, got %+v", result.Body)
}
body["duration"] = 10.1
body["duration_seconds"] = 10.1
result = preprocessRequestWithLog("videos.generations", body, candidate)
if result.Body["duration"] != float64(10) && result.Body["duration"] != 10 {
t.Fatalf("duration should be capped by range max, got %+v", result.Body["duration"])
}
}
func TestParamProcessorDurationOptionsChooseNextAllowedValue(t *testing.T) {
body := map[string]any{
"model": "Seedance",
"prompt": "animate it",
"duration": 8.1,
}
candidate := store.RuntimeModelCandidate{
ModelType: "image_to_video",
Capabilities: map[string]any{
"image_to_video": map[string]any{
"duration_options": []any{4, 8, 12},
},
},
}
result := preprocessRequestWithLog("videos.generations", body, candidate)
if result.Body["duration"] != float64(12) && result.Body["duration"] != 12 {
t.Fatalf("duration should use next allowed option, got %+v", result.Body["duration"])
}
}
func TestParamProcessorVideoGenerateLogsFirstFrameRemoval(t *testing.T) {
body := map[string]any{
"model": "Seedance T2V",

View File

@ -0,0 +1,511 @@
package runner
import (
"math"
"sort"
"strconv"
"strings"
)
func validateAndAdjustAspectRatio(aspectRatio string, capability map[string]any, allowed []string) (string, bool) {
if !isMediaModelTypeWithAspectRatio(capability) {
return "", false
}
if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok {
ratio, valid := aspectRatioNumber(aspectRatio)
if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] {
return adjustAspectRatioToRange(aspectRatio, ratioRange[0], ratioRange[1], allowed), true
}
}
if allowed == nil {
return aspectRatio, true
}
if len(allowed) == 0 {
return "", false
}
if (aspectRatio == "adaptive" || aspectRatio == "keep_ratio") && !containsString(allowed, aspectRatio) {
return "", false
}
if containsString(allowed, aspectRatio) {
return aspectRatio, true
}
return allowed[0], true
}
func isMediaModelTypeWithAspectRatio(capability map[string]any) bool {
return capability != nil
}
func aspectRatioAllowed(value any, resolution string) []string {
switch typed := value.(type) {
case []any:
return stringListFromAny(typed)
case []string:
return typed
case map[string]any:
if resolution != "" {
if values := stringListFromAny(typed[resolution]); len(values) > 0 {
return values
}
}
return nil
default:
return nil
}
}
func scopedNumberList(value any, scopes ...string) []float64 {
switch typed := value.(type) {
case []any:
out := make([]float64, 0, len(typed))
for _, item := range typed {
if number := floatFromAny(item); number > 0 {
out = append(out, number)
}
}
return out
case []float64:
return typed
case []int:
out := make([]float64, 0, len(typed))
for _, item := range typed {
out = append(out, float64(item))
}
return out
case map[string]any:
for _, scope := range scopes {
if scope == "" {
continue
}
if values := scopedNumberList(typed[scope]); len(values) > 0 {
return values
}
}
for _, item := range typed {
if values := scopedNumberList(item); len(values) > 0 {
return values
}
}
}
return nil
}
func scopedRange(value any, scopes ...string) (float64, float64, bool) {
if pair, ok := numberPair(value); ok {
return pair[0], pair[1], true
}
if typed, ok := value.(map[string]any); ok {
for _, scope := range scopes {
if scope == "" {
continue
}
if minValue, maxValue, ok := scopedRange(typed[scope]); ok {
return minValue, maxValue, true
}
}
for _, item := range typed {
if minValue, maxValue, ok := scopedRange(item); ok {
return minValue, maxValue, true
}
}
}
return 0, 0, false
}
func durationStep(value any, scopes ...string) float64 {
if step := floatFromAny(value); step > 0 {
return step
}
if typed, ok := value.(map[string]any); ok {
for _, scope := range scopes {
if scope == "" {
continue
}
if step := durationStep(typed[scope]); step > 0 {
return step
}
}
for _, item := range typed {
if step := durationStep(item); step > 0 {
return step
}
}
}
return 0
}
func normalizeDurationByRange(target float64, minValue float64, maxValue float64, step float64) float64 {
if minValue > maxValue {
minValue, maxValue = maxValue, minValue
}
if step <= 0 {
step = 1
}
clamped := math.Min(math.Max(target, minValue), maxValue)
snapped := math.Ceil(((clamped-minValue)/step)-1e-9)*step + minValue
snapped = math.Min(math.Max(snapped, minValue), maxValue)
return math.Round(snapped*1_000_000) / 1_000_000
}
func normalizeDurationByStep(target float64, step float64) float64 {
if step <= 0 {
step = 1
}
snapped := math.Ceil((target/step)-1e-9) * step
return math.Round(snapped*1_000_000) / 1_000_000
}
func nextAllowedNumber(target float64, values []float64) float64 {
if len(values) == 0 {
return target
}
sorted := append([]float64(nil), values...)
sort.Float64s(sorted)
for _, value := range sorted {
if value >= target || math.Abs(value-target) < 1e-9 {
return value
}
}
return sorted[len(sorted)-1]
}
func contentItems(value any) []map[string]any {
switch typed := value.(type) {
case []any:
out := make([]map[string]any, 0, len(typed))
for _, item := range typed {
if object, ok := item.(map[string]any); ok {
out = append(out, cloneMap(object))
}
}
return out
case []map[string]any:
out := make([]map[string]any, 0, len(typed))
for _, item := range typed {
out = append(out, cloneMap(item))
}
return out
default:
return nil
}
}
func mapsToAnySlice(values []map[string]any) []any {
out := make([]any, 0, len(values))
for _, value := range values {
out = append(out, value)
}
return out
}
func isImageContent(item map[string]any) bool {
return stringFromAny(item["type"]) == "image_url" || item["image_url"] != nil
}
func isVideoContent(item map[string]any) bool {
return stringFromAny(item["type"]) == "video_url" || item["video_url"] != nil
}
func isAudioContent(item map[string]any) bool {
return stringFromAny(item["type"]) == "audio_url" || item["audio_url"] != nil
}
func capabilityForType(capabilities map[string]any, modelType string) map[string]any {
if capabilities == nil {
return nil
}
if typed, ok := capabilities[modelType].(map[string]any); ok {
return typed
}
return nil
}
func capabilityPath(modelType string, key string) string {
modelType = strings.TrimSpace(modelType)
if modelType == "" {
return ""
}
if strings.TrimSpace(key) == "" {
return "capabilities." + modelType
}
return "capabilities." + modelType + "." + key
}
func capabilityValue(capabilities map[string]any, modelType string, key string) any {
capability := capabilityForType(capabilities, modelType)
if capability == nil {
return nil
}
if strings.TrimSpace(key) == "" {
return cloneMap(capability)
}
return cloneAny(capability[key])
}
func capabilityEvidence(capabilities map[string]any, modelType string, key string) (string, any) {
return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key)
}
func audioInputCapabilityEvidence(context *paramProcessContext, modelType string) (string, any) {
if isOmniVideoLike(context) {
path, value := omniCapabilityEvidence(context, "input_audio")
return path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios"))
}
return capabilityEvidence(context.modelCapability, modelType, "input_audio")
}
func omniCapabilityType(context *paramProcessContext) string {
if context != nil && capabilityForType(context.modelCapability, "omni_video") != nil {
return "omni_video"
}
if context != nil && capabilityForType(context.modelCapability, "omni") != nil {
return "omni"
}
return "omni_video"
}
func omniCapabilityEvidence(context *paramProcessContext, key string) (string, any) {
modelType := omniCapabilityType(context)
var capabilities map[string]any
if context != nil {
capabilities = context.modelCapability
}
return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key)
}
func omniCapabilityBundle(context *paramProcessContext, keys ...string) map[string]any {
modelType := omniCapabilityType(context)
var capabilities map[string]any
if context != nil {
capabilities = context.modelCapability
}
out := map[string]any{}
for _, key := range keys {
out[key] = capabilityValue(capabilities, modelType, key)
}
return out
}
func numericField(values map[string]any, key string) (float64, bool) {
if values == nil {
return 0, false
}
if _, ok := values[key]; !ok {
return 0, false
}
return floatFromAny(values[key]), true
}
func boolFromAny(value any) bool {
typed, _ := value.(bool)
return typed
}
func firstNonEmptyStringValue(values map[string]any, keys ...string) string {
for _, key := range keys {
if value := stringFromAny(values[key]); value != "" {
return value
}
}
return ""
}
func firstNonEmptyStringListFromAny(values ...any) []string {
for _, value := range values {
items := stringListFromAny(value)
if len(items) > 0 {
return items
}
}
return nil
}
func stringListFromAny(value any) []string {
switch typed := value.(type) {
case []string:
out := make([]string, 0, len(typed))
for _, item := range typed {
if text := strings.TrimSpace(item); text != "" {
out = append(out, text)
}
}
return out
case []any:
out := make([]string, 0, len(typed))
for _, item := range typed {
if text := stringFromAny(item); text != "" {
out = append(out, text)
}
}
return out
case string:
if strings.TrimSpace(typed) == "" {
return nil
}
return []string{strings.TrimSpace(typed)}
default:
return nil
}
}
func containsString(values []string, target string) bool {
for _, value := range values {
if value == target {
return true
}
}
return false
}
func appendUniqueString(values *[]string, value string) {
value = strings.TrimSpace(value)
if value == "" {
return
}
for _, existing := range *values {
if existing == value {
return
}
}
*values = append(*values, value)
}
func numberPair(value any) ([2]float64, bool) {
switch typed := value.(type) {
case []any:
if len(typed) < 2 {
return [2]float64{}, false
}
return [2]float64{floatFromAny(typed[0]), floatFromAny(typed[1])}, true
case []float64:
if len(typed) < 2 {
return [2]float64{}, false
}
return [2]float64{typed[0], typed[1]}, true
case []int:
if len(typed) < 2 {
return [2]float64{}, false
}
return [2]float64{float64(typed[0]), float64(typed[1])}, true
default:
return [2]float64{}, false
}
}
func validAspectRatio(value string) bool {
if value == "adaptive" || value == "keep_ratio" {
return true
}
_, ok := aspectRatioNumber(value)
return ok
}
func aspectRatioNumber(value string) (float64, bool) {
parts := strings.Split(value, ":")
if len(parts) != 2 {
return 0, false
}
width := parsePositiveFloat(parts[0])
height := parsePositiveFloat(parts[1])
if width <= 0 || height <= 0 {
return 0, false
}
return width / height, true
}
func adjustAspectRatioToRange(value string, minValue float64, maxValue float64, allowed []string) string {
current, ok := aspectRatioNumber(value)
if !ok {
if len(allowed) > 0 {
return allowed[0]
}
return "1:1"
}
if len(allowed) > 0 {
closest := ""
minDiff := math.Inf(1)
for _, candidate := range allowed {
ratio, ok := aspectRatioNumber(candidate)
if !ok || ratio < minValue || ratio > maxValue {
continue
}
diff := math.Abs(ratio - current)
if diff < minDiff {
minDiff = diff
closest = candidate
}
}
if closest != "" {
return closest
}
}
if current < minValue {
return ratioString(minValue)
}
return ratioString(maxValue)
}
func ratioString(value float64) string {
if value <= 0 {
return "1:1"
}
return strings.TrimRight(strings.TrimRight(strconv.FormatFloat(value, 'f', 6, 64), "0"), ".") + ":1"
}
func parsePositiveFloat(value string) float64 {
for _, r := range strings.TrimSpace(value) {
if r < '0' || r > '9' {
if r != '.' {
return 0
}
}
}
out, _ := strconv.ParseFloat(strings.TrimSpace(value), 64)
return out
}
func isEmptyParamString(value string) bool {
normalized := strings.ToLower(strings.TrimSpace(value))
return normalized == "null" || normalized == "undefined"
}
func isImageResolution(modelType string, value string) bool {
return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "4K", "8K"}, value)
}
func isVideoResolution(modelType string, value string) bool {
return isVideoModelType(modelType) && containsString([]string{"480p", "720p", "1080p", "1440p", "2160p"}, value)
}
func isVideoModelType(modelType string) bool {
return modelType == "video_generate" || modelType == "text_to_video" || modelType == "image_to_video" || modelType == "video_edit" || modelType == "video_reference" || modelType == "video_first_last_frame" || modelType == "omni_video" || modelType == "omni"
}
func cloneMap(values map[string]any) map[string]any {
out := map[string]any{}
for key, value := range values {
out[key] = cloneAny(value)
}
return out
}
func cloneAny(value any) any {
switch typed := value.(type) {
case map[string]any:
return cloneMap(typed)
case []any:
out := make([]any, 0, len(typed))
for _, item := range typed {
out = append(out, cloneAny(item))
}
return out
case []map[string]any:
out := make([]any, 0, len(typed))
for _, item := range typed {
out = append(out, cloneMap(item))
}
return out
default:
return value
}
}

View File

@ -0,0 +1,663 @@
package runner
import (
"fmt"
"math"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
type contentFilterProcessor struct{}
func (contentFilterProcessor) Name() string { return "ContentFilterProcessor" }
func (contentFilterProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
_, ok := params["content"]
return ok
}
func (contentFilterProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
content := contentItems(params["content"])
if len(content) == 0 {
return true
}
if isOmniVideoLike(context) {
filtered := filterUnsupportedOmniVideoContent(content, context)
params["content"] = mapsToAnySlice(filtered)
syncVideoConvenienceFields(params, filtered, context)
return true
}
if err := downgradeReferenceImageIfNeeded(params, content, modelType, context); err != nil {
return false
}
if modelType == "video_generate" || modelType == "text_to_video" {
next := make([]map[string]any, 0, len(content))
for index, item := range content {
if isImageContent(item) {
reason, path, value := imageContentRemovalEvidence(item, modelType, context)
context.recordChange(
"ContentFilterProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
reason,
path,
value,
)
continue
}
next = append(next, item)
}
content = next
}
if modelType == "image_to_video" || modelType == "omni_video" || modelType == "omni" {
if !supportsFirstAndLastFrame(context.modelCapability, modelType) {
next := make([]map[string]any, 0, len(content))
for index, item := range content {
if stringFromAny(item["role"]) == "last_frame" {
context.recordChange(
"ContentFilterProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
"模型不支持首尾帧输入,已移除 last_frame。",
capabilityPath(modelType, "input_first_last_frame"),
map[string]any{
"input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
"max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"),
},
)
continue
}
next = append(next, item)
}
content = next
deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"last_frame", "lastFrame"}, "模型不支持首尾帧输入,已移除快捷字段。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{
"input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
"max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"),
})
}
}
params["content"] = mapsToAnySlice(content)
return true
}
func imageContentRemovalEvidence(item map[string]any, modelType string, context *paramProcessContext) (string, string, any) {
role := stringFromAny(item["role"])
switch role {
case "first_frame":
return "模型能力未开启首帧输入,已移除 first_frame。", capabilityPath(modelType, "input_first_frame"), map[string]any{
"input_first_frame": capabilityValue(context.modelCapability, modelType, "input_first_frame"),
"input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
}
case "last_frame":
return "模型能力未开启尾帧或首尾帧输入,已移除 last_frame。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{
"input_last_frame": capabilityValue(context.modelCapability, modelType, "input_last_frame"),
"input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
"max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"),
"max_images_for_first_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_first_frame"),
"max_images_for_middle_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_middle_frame"),
}
case "reference_image":
return "模型能力未开启参考图输入,已移除 reference_image。", capabilityPath(modelType, "input_reference_generate_single"), map[string]any{
"input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"),
"input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"),
"max_images": capabilityValue(context.modelCapability, modelType, "max_images"),
}
default:
return "当前模型能力未开启图像输入,已移除 image_url。", capabilityPath(modelType, "input_first_frame"), map[string]any{
"input_first_frame": capabilityValue(context.modelCapability, modelType, "input_first_frame"),
"input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
"input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"),
"input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"),
}
}
}
func ensureVideoContent(params map[string]any, context *paramProcessContext) {
if len(contentItems(params["content"])) > 0 {
return
}
content := make([]map[string]any, 0)
if prompt := firstNonEmptyString(stringFromAny(params["prompt"]), stringFromAny(params["input"])); prompt != "" {
content = append(content, map[string]any{"type": "text", "text": prompt})
}
appendURL := func(kind string, role string, url string) {
url = strings.TrimSpace(url)
if url == "" {
return
}
item := map[string]any{"type": kind, "role": role}
switch kind {
case "image_url":
item["image_url"] = map[string]any{"url": url}
case "video_url":
item["video_url"] = map[string]any{"url": url}
case "audio_url":
item["audio_url"] = map[string]any{"url": url}
}
content = append(content, item)
}
firstFrame := firstNonEmptyStringValue(params, "first_frame", "firstFrame")
appendURL("image_url", "first_frame", firstFrame)
appendURL("image_url", "last_frame", firstNonEmptyStringValue(params, "last_frame", "lastFrame"))
imageURLs := firstNonEmptyStringListFromAny(params["image"], params["images"], params["image_url"], params["imageUrl"], params["image_urls"], params["imageUrls"])
if firstFrame == "" && len(imageURLs) > 0 {
appendURL("image_url", "first_frame", imageURLs[0])
imageURLs = imageURLs[1:]
}
for _, url := range imageURLs {
appendURL("image_url", "reference_image", url)
}
for _, url := range firstNonEmptyStringListFromAny(params["reference_image"], params["referenceImage"]) {
appendURL("image_url", "reference_image", url)
}
for _, url := range firstNonEmptyStringListFromAny(params["video"], params["video_url"], params["videoUrl"], params["reference_video"], params["referenceVideo"]) {
appendURL("video_url", "reference_video", url)
}
for _, url := range firstNonEmptyStringListFromAny(params["audio_url"], params["audioUrl"], params["reference_audio"], params["referenceAudio"]) {
appendURL("audio_url", "reference_audio", url)
}
if len(content) > 0 {
params["content"] = mapsToAnySlice(content)
context.recordChange(
"ContentBuildProcessor",
"set",
"content",
nil,
params["content"],
"将 prompt/first_frame/reference_* 等快捷字段转换为 content 数组,后续处理器可按模型能力逐项过滤。",
"",
nil,
)
}
}
func effectiveModelCapability(candidate store.RuntimeModelCandidate) map[string]any {
base := cloneMap(candidate.Capabilities)
for key, value := range candidate.CapabilityOverride {
if baseChild, ok := base[key].(map[string]any); ok {
if overrideChild, ok := value.(map[string]any); ok {
base[key] = mergeMap(baseChild, overrideChild)
continue
}
}
base[key] = cloneAny(value)
}
return base
}
func filterUnsupportedOmniVideoContent(content []map[string]any, context *paramProcessContext) []map[string]any {
capability := omniVideoCapability(context)
maxVideos := math.Inf(1)
if capability != nil {
if value, ok := numericField(capability, "max_videos"); ok {
maxVideos = value
}
}
maxAudios := 0.0
if capability != nil {
if value, ok := numericField(capability, "max_audios"); ok {
maxAudios = value
} else if supportsOmniAudioReference(context) {
maxAudios = math.Inf(1)
}
}
videoCount := 0.0
audioCount := 0.0
out := make([]map[string]any, 0, len(content))
for index, item := range content {
if isVideoContent(item) {
if !supportsOmniVideoReference(item, capability) {
path, value := omniCapabilityEvidence(context, "supported_modes")
context.recordChange(
"ContentFilterProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
"视频参考类型不在 omni_video.supported_modes 允许范围内。",
path,
value,
)
continue
}
if videoCount >= maxVideos {
path, value := omniCapabilityEvidence(context, "max_videos")
context.recordChange(
"ContentFilterProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
"视频参考数量超过 omni_video.max_videos 限制。",
path,
value,
)
continue
}
videoCount++
out = append(out, item)
continue
}
if isAudioContent(item) {
if !supportsOmniAudioReference(context) {
path, value := omniCapabilityEvidence(context, "input_audio")
context.recordChange(
"ContentFilterProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
"模型能力不支持音频参考,已移除 audio_url。",
path,
mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")),
)
continue
}
if audioCount >= maxAudios {
path, value := omniCapabilityEvidence(context, "max_audios")
context.recordChange(
"ContentFilterProcessor",
"remove",
fmt.Sprintf("content[%d]", index),
item,
nil,
"音频参考数量超过 omni_video.max_audios 限制。",
path,
value,
)
continue
}
audioCount++
out = append(out, item)
continue
}
out = append(out, item)
}
return out
}
func isOmniVideoLike(context *paramProcessContext) bool {
modelType := strings.TrimSpace(context.candidate.ModelType)
return modelType == "omni_video" ||
modelType == "omni" ||
context.modelCapability["omni_video"] != nil ||
context.modelCapability["omni"] != nil
}
func omniVideoCapability(context *paramProcessContext) map[string]any {
if capability := capabilityForType(context.modelCapability, "omni_video"); capability != nil {
return capability
}
return capabilityForType(context.modelCapability, "omni")
}
func supportsOmniAudioReference(context *paramProcessContext) bool {
capability := omniVideoCapability(context)
return capability != nil && (boolFromAny(capability["input_audio"]) || floatFromAny(capability["max_audios"]) > 0)
}
func supportsOmniVideoReference(item map[string]any, capability map[string]any) bool {
if capability == nil {
return true
}
if value, ok := numericField(capability, "max_videos"); ok && value == 0 {
return false
}
supportedModes := stringListFromAny(capability["supported_modes"])
supportsReference := containsString(supportedModes, "video_reference")
supportsEdit := containsString(supportedModes, "video_edit")
video, _ := item["video_url"].(map[string]any)
referType := stringFromAny(video["refer_type"])
isEditVideo := stringFromAny(item["role"]) == "video_base" || referType == "base"
isReferenceVideo := stringFromAny(item["role"]) == "video_feature" ||
stringFromAny(item["role"]) == "reference_video" ||
referType == "feature"
if isEditVideo {
return supportsEdit
}
if isReferenceVideo {
return supportsReference
}
return supportsReference || supportsEdit
}
func downgradeReferenceImageIfNeeded(params map[string]any, content []map[string]any, modelType string, context *paramProcessContext) error {
if !isVideoModelType(modelType) {
return nil
}
if supportsReferenceImage(context.modelCapability, modelType) {
return nil
}
imageIndexes := make([]int, 0)
referenceIndexes := make([]int, 0)
hasVideoOrAudioReference := false
for index, item := range content {
if isVideoContent(item) || isAudioContent(item) {
hasVideoOrAudioReference = true
continue
}
if !isImageContent(item) {
continue
}
imageIndexes = append(imageIndexes, index)
role := stringFromAny(item["role"])
if role == "" || role == "reference_image" {
referenceIndexes = append(referenceIndexes, index)
}
}
if len(referenceIndexes) == 0 {
return nil
}
evidence := referenceImageDowngradeCapabilityEvidence(context.modelCapability, modelType)
if hasVideoOrAudioReference {
context.reject(
"ContentFilterProcessor",
"content",
content,
"当前模型不支持多模态参考,不能将视频或音频参考降级为首尾帧,请移除视频/音频参考或选择支持多模态参考的模型。",
evidence.path,
evidence.value,
)
return context.err
}
if len(imageIndexes) > 2 {
context.reject(
"ContentFilterProcessor",
"content",
content,
"当前模型不支持多参考图输入,最多只允许 2 张图片降级为首尾帧。",
evidence.path,
evidence.value,
)
return context.err
}
if len(imageIndexes) == 2 && !supportsFirstAndLastFrame(context.modelCapability, modelType) {
context.reject(
"ContentFilterProcessor",
"content",
content,
"当前模型不支持首尾帧输入,不能将 2 张参考图降级为首尾帧。",
evidence.path,
evidence.value,
)
return context.err
}
if len(imageIndexes) == 1 && !supportsFirstFrame(context.modelCapability, modelType) {
context.reject(
"ContentFilterProcessor",
"content",
content,
"当前模型不支持首帧输入,不能将参考图降级为首帧。",
evidence.path,
evidence.value,
)
return context.err
}
if len(imageIndexes) == 1 {
adjustImageContentRole(content, imageIndexes[0], "first_frame", context, modelType, "模型不支持 reference_image且只有 1 张图片,已降级为 first_frame。")
appendParamWarning(params, "reference_image is unsupported by the selected model and was downgraded to first_frame")
return nil
}
firstIndex, lastIndex := firstLastFrameIndexes(content, imageIndexes)
adjustImageContentRole(content, firstIndex, "first_frame", context, modelType, "模型不支持 reference_image2 张图片已降级为首尾帧的 first_frame。")
adjustImageContentRole(content, lastIndex, "last_frame", context, modelType, "模型不支持 reference_image2 张图片已降级为首尾帧的 last_frame。")
appendParamWarning(params, "reference_image is unsupported by the selected model and was downgraded to first/last frame")
return nil
}
type capabilityEvidenceValue struct {
path string
value any
}
func referenceImageDowngradeCapabilityEvidence(modelCapability map[string]any, modelType string) capabilityEvidenceValue {
actualType, capability := firstVideoInputCapability(modelCapability, modelType)
if actualType == "" {
actualType = modelType
}
value := map[string]any{}
if capability != nil {
for _, key := range []string{
"input_reference_generate_single",
"input_reference_generate_multiple",
"max_images",
"input_first_frame",
"input_first_last_frame",
"max_images_for_last_frame",
} {
value[key] = cloneAny(capability[key])
}
}
return capabilityEvidenceValue{path: capabilityPath(actualType, ""), value: value}
}
func adjustImageContentRole(content []map[string]any, index int, role string, context *paramProcessContext, modelType string, reason string) {
if index < 0 || index >= len(content) {
return
}
item := content[index]
if stringFromAny(item["role"]) == role {
return
}
before := cloneMap(item)
item["role"] = role
context.recordChange(
"ContentFilterProcessor",
"adjust",
fmt.Sprintf("content[%d].role", index),
before,
item,
reason,
capabilityPath(modelType, "input_reference_generate_single"),
referenceImageDowngradeCapabilityEvidence(context.modelCapability, modelType).value,
)
}
func firstLastFrameIndexes(content []map[string]any, imageIndexes []int) (int, int) {
firstIndex := -1
lastIndex := -1
for _, index := range imageIndexes {
switch stringFromAny(content[index]["role"]) {
case "first_frame":
if firstIndex == -1 {
firstIndex = index
}
case "last_frame":
if lastIndex == -1 {
lastIndex = index
}
}
}
if firstIndex == -1 && lastIndex == -1 {
return imageIndexes[0], imageIndexes[1]
}
if firstIndex == -1 {
for _, index := range imageIndexes {
if index != lastIndex {
firstIndex = index
break
}
}
}
if lastIndex == -1 {
for _, index := range imageIndexes {
if index != firstIndex {
lastIndex = index
break
}
}
}
if firstIndex == lastIndex {
return imageIndexes[0], imageIndexes[1]
}
return firstIndex, lastIndex
}
type videoInputCapabilityValue struct {
modelType string
capability map[string]any
}
func firstVideoInputCapability(modelCapability map[string]any, modelType string) (string, map[string]any) {
for _, candidate := range videoInputCapabilityCandidates(modelCapability, modelType) {
return candidate.modelType, candidate.capability
}
return "", nil
}
func videoInputCapabilityCandidates(modelCapability map[string]any, modelType string) []videoInputCapabilityValue {
keys := []string{modelType, "image_to_video", "video_first_last_frame"}
if modelType == "omni_video" || modelType == "omni" {
keys = append(keys, "omni_video", "omni")
}
seen := map[string]bool{}
out := make([]videoInputCapabilityValue, 0, len(keys))
for _, key := range keys {
key = strings.TrimSpace(key)
if key == "" || seen[key] {
continue
}
seen[key] = true
if capability := capabilityForType(modelCapability, key); capability != nil {
out = append(out, videoInputCapabilityValue{modelType: key, capability: capability})
}
}
return out
}
func supportsReferenceImage(modelCapability map[string]any, modelType string) bool {
candidates := videoInputCapabilityCandidates(modelCapability, modelType)
if len(candidates) == 0 {
return true
}
for _, candidate := range candidates {
capability := candidate.capability
_, hasSingle := capability["input_reference_generate_single"]
_, hasMultiple := capability["input_reference_generate_multiple"]
if hasSingle || hasMultiple {
if boolFromAny(capability["input_reference_generate_single"]) || boolFromAny(capability["input_reference_generate_multiple"]) {
return true
}
continue
}
if value, ok := numericField(capability, "max_images"); ok {
if value > 1 {
return true
}
continue
}
}
return false
}
func supportsFirstFrame(modelCapability map[string]any, modelType string) bool {
for _, candidate := range videoInputCapabilityCandidates(modelCapability, modelType) {
capability := candidate.capability
if boolFromAny(capability["input_first_frame"]) ||
boolFromAny(capability["input_first_last_frame"]) ||
floatFromAny(capability["max_images_for_first_frame"]) > 0 ||
floatFromAny(capability["max_images_for_last_frame"]) > 0 {
return true
}
}
return false
}
func supportsFirstAndLastFrame(modelCapability map[string]any, modelType string) bool {
for _, candidate := range videoInputCapabilityCandidates(modelCapability, modelType) {
capability := candidate.capability
if boolFromAny(capability["input_first_last_frame"]) || floatFromAny(capability["max_images_for_last_frame"]) > 0 {
return true
}
}
return false
}
func videoModeKey(params map[string]any) string {
content := contentItems(params["content"])
hasFirstFrame := false
hasLastFrame := false
for _, item := range content {
switch stringFromAny(item["role"]) {
case "first_frame":
hasFirstFrame = true
case "last_frame":
hasLastFrame = true
}
}
switch {
case hasFirstFrame && hasLastFrame:
return "input_first_last_frame"
case hasFirstFrame:
return "input_first_frame"
case hasLastFrame:
return "input_last_frame"
default:
return ""
}
}
func syncDurationSeconds(params map[string]any) {
if params["duration_seconds"] != nil {
params["duration_seconds"] = params["duration"]
}
}
func syncVideoConvenienceFields(params map[string]any, content []map[string]any, context *paramProcessContext) {
hasVideo := false
hasAudio := false
for _, item := range content {
hasVideo = hasVideo || isVideoContent(item)
hasAudio = hasAudio || isAudioContent(item)
}
if !hasVideo {
path, value := omniCapabilityEvidence(context, "supported_modes")
deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"video", "video_url", "videoUrl", "reference_video", "referenceVideo"}, "对应视频 content 已被模型能力过滤,移除视频参考快捷字段。", path, value)
}
if !hasAudio {
path, value := omniCapabilityEvidence(context, "input_audio")
deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "对应音频 content 已被模型能力过滤,移除音频参考快捷字段。", path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")))
}
}
func deleteFieldsWithLog(params map[string]any, context *paramProcessContext, processor string, keys []string, reason string, capabilityPath string, capabilityValue any) {
for _, key := range keys {
if before, ok := params[key]; ok {
delete(params, key)
context.recordChange(processor, "remove", key, before, nil, reason, capabilityPath, capabilityValue)
}
}
}
func appendParamWarning(params map[string]any, warning string) {
warnings, _ := params["_param_warnings"].([]any)
for _, item := range warnings {
if stringFromAny(item) == warning {
return
}
}
params["_param_warnings"] = append(warnings, warning)
}
func filterContent(content []map[string]any, keep func(map[string]any) bool) []map[string]any {
out := make([]map[string]any, 0, len(content))
for _, item := range content {
if keep(item) {
out = append(out, item)
}
}
return out
}

View File

@ -11,8 +11,10 @@ import (
)
type EstimateResult struct {
Items []any `json:"items"`
Resolver string `json:"resolver"`
Items []any `json:"items"`
Resolver string `json:"resolver"`
TotalAmount float64 `json:"totalAmount"`
Currency string `json:"currency"`
}
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
@ -23,9 +25,12 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body
}
candidate := candidates[0]
body = preprocessRequest(kind, body, candidate)
items := s.estimatedBillings(ctx, user, kind, body, candidate)
return EstimateResult{
Items: s.estimatedBillings(ctx, user, kind, body, candidate),
Resolver: "effective-pricing-v1",
Items: items,
Resolver: "effective-pricing-v1",
TotalAmount: totalBillingAmount(items),
Currency: billingCurrency(items),
}, nil
}
@ -60,10 +65,7 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
billingLine(candidate, "text_output", "1k_tokens", outputTokens, outputAmount, discount, simulated),
}
}
count := int(floatFromAny(body["n"]))
if count <= 0 {
count = 1
}
count := requestOutputCount(body)
resource := "image"
unit := "image"
baseKey := "imageBase"
@ -73,8 +75,24 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
}
if kind == "videos.generations" {
resource = "video"
unit = "video"
unit = "5s_video"
baseKey = "videoBase"
duration := requestDurationSeconds(body)
durationUnits := math.Max(1, math.Ceil(duration/5))
amount := float64(count) *
durationUnits *
resourcePrice(config, resource, baseKey, "basePrice") *
resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) *
resourceWeight(config, resource, "audioWeights", boolWeightKey(boolishValue(body["audio"]))) *
resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) *
resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body))) *
discount
return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{
"count": count,
"durationSeconds": duration,
"durationUnit": "5s",
"durationUnitCount": durationUnits,
})}
}
amount := float64(count) * resourcePrice(config, resource, baseKey, "basePrice") * resourceWeight(config, resource, "qualityWeights", stringFromMap(body, "quality")) * resourceWeight(config, resource, "sizeWeights", stringFromMap(body, "size")) * resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * discount
return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)}
@ -109,17 +127,23 @@ func effectiveDiscount(ctx context.Context, db *store.Store, user *auth.User, ca
if discount <= 0 {
discount = 1
}
if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil {
groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"])
if groupDiscount > 0 {
discount *= groupDiscount
if db != nil {
if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil {
groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"])
if groupDiscount > 0 {
discount *= groupDiscount
}
}
}
return discount
}
func billingLine(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool) map[string]any {
return map[string]any{
return billingLineWithDetails(candidate, resourceType, unit, quantity, amount, discount, simulated, nil)
}
func billingLineWithDetails(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool, details map[string]any) map[string]any {
line := map[string]any{
"model": candidate.ModelName,
"modelAlias": candidate.ModelAlias,
"provider": candidate.Provider,
@ -133,6 +157,10 @@ func billingLine(candidate store.RuntimeModelCandidate, resourceType string, uni
"discountFactor": discount,
"simulated": simulated,
}
for key, value := range details {
line[key] = value
}
return line
}
func price(config map[string]any, key string) float64 {
@ -177,7 +205,16 @@ func weighted(config map[string]any, key string, name string) float64 {
}
func resourceWeight(config map[string]any, resource string, key string, name string) float64 {
if value := weighted(config, key, name); value != 1 {
keys := weightKeyAliases(key)
names := weightValueAliases(key, name)
for _, candidateKey := range keys {
for _, candidateName := range names {
if value := weighted(config, candidateKey, candidateName); value != 1 {
return value
}
}
}
if value := dynamicWeight(config["dynamicWeight"], keys, names); value != 1 {
return value
}
if strings.TrimSpace(name) == "" {
@ -187,19 +224,201 @@ func resourceWeight(config map[string]any, resource string, key string, name str
if len(resourceConfig) == 0 && resource == "image_edit" {
resourceConfig, _ = config["image"].(map[string]any)
}
if weights, ok := resourceConfig["dynamicWeight"].(map[string]any); ok {
if value := floatFromAny(weights[name]); value > 0 {
return value
if value := dynamicWeight(resourceConfig["dynamicWeight"], keys, names); value != 1 {
return value
}
for _, candidateKey := range keys {
if weights, ok := resourceConfig[candidateKey].(map[string]any); ok {
for _, candidateName := range names {
if value := floatFromAny(weights[candidateName]); value > 0 {
return value
}
}
}
}
if weights, ok := resourceConfig[key].(map[string]any); ok {
if value := floatFromAny(weights[name]); value > 0 {
return 1
}
func dynamicWeight(value any, keys []string, names []string) float64 {
if len(names) == 0 {
return 1
}
weights, _ := value.(map[string]any)
if len(weights) == 0 {
return 1
}
for _, name := range names {
if direct := floatFromAny(weights[name]); direct > 0 {
return direct
}
}
for _, key := range keys {
if nested, ok := weights[key].(map[string]any); ok {
for _, name := range names {
if nestedValue := floatFromAny(nested[name]); nestedValue > 0 {
return nestedValue
}
}
}
}
return 1
}
func weightKeyAliases(key string) []string {
switch key {
case "qualityWeights":
return []string{"qualityWeights", "qualityFactors"}
case "resolutionWeights":
return []string{"resolutionWeights", "resolutionFactors"}
case "audioWeights":
return []string{"audioWeights", "audioFactors"}
case "referenceVideoWeights":
return []string{"referenceVideoWeights", "referenceVideoFactors"}
case "voiceSpecifiedWeights":
return []string{"voiceSpecifiedWeights", "voiceSpecifiedFactors"}
default:
return []string{key}
}
}
func weightValueAliases(key string, name string) []string {
name = strings.TrimSpace(name)
if name == "" {
return nil
}
switch key {
case "audioWeights":
return []string{name, "audio-" + name}
case "referenceVideoWeights":
return []string{name, "reference-video-" + name}
case "voiceSpecifiedWeights":
return []string{name, "voice-specified-" + name}
default:
return []string{name}
}
}
func requestOutputCount(body map[string]any) int {
for _, key := range []string{"n", "count", "batch_size", "batchSize"} {
if value := int(math.Ceil(floatFromAny(body[key]))); value > 0 {
return value
}
}
return 1
}
func requestDurationSeconds(body map[string]any) float64 {
for _, key := range []string{"duration", "durationSeconds", "duration_seconds"} {
if value := floatFromAny(body[key]); value > 0 {
return value
}
}
for _, value := range body {
items, ok := value.([]any)
if !ok || len(items) == 0 {
continue
}
total := 0.0
allDurationItems := true
for _, item := range items {
record, ok := item.(map[string]any)
if !ok {
allDurationItems = false
break
}
duration := floatFromAny(record["duration"])
if duration <= 0 {
allDurationItems = false
break
}
total += duration
}
if allDurationItems && total > 0 {
return total
}
}
return 5
}
func requestHasReferenceVideo(body map[string]any) bool {
if hasNonEmptyArray(body["video_list"]) || hasNonEmptyArray(body["videoList"]) {
return true
}
if firstNonEmptyStringValue(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") != "" {
return true
}
content, _ := body["content"].([]any)
for _, item := range content {
record, _ := item.(map[string]any)
if len(record) == 0 {
continue
}
itemType := strings.TrimSpace(stringFromAny(record["type"]))
role := strings.TrimSpace(stringFromAny(record["role"]))
if itemType == "video_url" || role == "video_feature" || role == "video_base" || role == "reference_video" {
return true
}
}
return false
}
func requestHasVoiceID(body map[string]any) bool {
return boolishValue(body["audio"]) && firstNonEmptyStringValue(body, "voice_id", "voiceId") != ""
}
func boolWeightKey(value bool) string {
if value {
return "true"
}
return "false"
}
func boolishValue(value any) bool {
switch typed := value.(type) {
case bool:
return typed
case string:
switch strings.ToLower(strings.TrimSpace(typed)) {
case "true", "1", "yes", "on":
return true
default:
return false
}
case int:
return typed != 0
case int64:
return typed != 0
case float64:
return typed != 0
default:
return false
}
}
func hasNonEmptyArray(value any) bool {
items, ok := value.([]any)
return ok && len(items) > 0
}
func totalBillingAmount(items []any) float64 {
total := 0.0
for _, raw := range items {
line, _ := raw.(map[string]any)
total += floatFromAny(line["amount"])
}
return roundPrice(total)
}
func billingCurrency(items []any) string {
for _, raw := range items {
line, _ := raw.(map[string]any)
if currency := stringFromAny(line["currency"]); currency != "" {
return currency
}
}
return "resource"
}
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {

View File

@ -0,0 +1,132 @@
package runner
import (
"context"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func TestImageBillingEstimateUsesCountResolutionAndQuality(t *testing.T) {
service := &Service{}
candidate := store.RuntimeModelCandidate{
ModelName: "image-model",
BaseBillingConfig: map[string]any{
"image": map[string]any{
"basePrice": 10,
"dynamicWeight": map[string]any{
"resolutionFactors": map[string]any{"2K": 1.5},
"qualityFactors": map[string]any{"high": 1.5},
},
},
},
}
items := service.billings(context.Background(), nil, "images.generations", map[string]any{
"count": 2,
"quality": "high",
"resolution": "2K",
}, candidate, clients.Response{}, true)
line := firstBillingLine(t, items)
if got, want := floatFromAny(line["amount"]), 45.0; got != want {
t.Fatalf("image estimated amount = %v, want %v", got, want)
}
if got, want := line["quantity"], 2; got != want {
t.Fatalf("image quantity = %v, want %v", got, want)
}
}
func TestVideoBillingEstimateUsesFiveSecondUnitsAndDynamicWeights(t *testing.T) {
service := &Service{}
candidate := store.RuntimeModelCandidate{
ModelName: "video-model",
BaseBillingConfig: map[string]any{
"video": map[string]any{
"basePrice": 100,
"dynamicWeight": map[string]any{
"resolutionWeights": map[string]any{"1080p": 1.5},
"audioWeights": map[string]any{"true": 2},
"referenceVideoWeights": map[string]any{"true": 1.5},
"voiceSpecifiedWeights": map[string]any{"true": 1.2},
"unusedCompatibilityField": map[string]any{"true": 99},
},
},
},
}
items := service.billings(context.Background(), nil, "videos.generations", map[string]any{
"audio": true,
"duration": 12,
"resolution": "1080p",
"voice_id": "voice-a",
"content": []any{
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/reference.mp4"}},
},
}, candidate, clients.Response{}, true)
line := firstBillingLine(t, items)
if got, want := floatFromAny(line["amount"]), 1620.0; got != want {
t.Fatalf("video estimated amount = %v, want %v", got, want)
}
if got, want := floatFromAny(line["durationUnitCount"]), 3.0; got != want {
t.Fatalf("video duration units = %v, want %v", got, want)
}
if got, want := line["quantity"], 3; got != want {
t.Fatalf("video quantity = %v, want %v", got, want)
}
}
func TestVideoBillingEstimateSupportsServerMainStyleDynamicKeys(t *testing.T) {
service := &Service{}
candidate := store.RuntimeModelCandidate{
ModelName: "legacy-video-model",
BaseBillingConfig: map[string]any{
"videoBase": 100,
"video": map[string]any{
"dynamicWeight": map[string]any{
"720p": 1.25,
"audio-true": 2,
"reference-video-true": 1.5,
},
},
},
}
items := service.billings(context.Background(), nil, "videos.generations", map[string]any{
"audio": "true",
"duration": 5,
"resolution": "720p",
"video_list": []any{map[string]any{"url": "https://example.com/reference.mp4"}},
}, candidate, clients.Response{}, true)
line := firstBillingLine(t, items)
if got, want := floatFromAny(line["amount"]), 375.0; got != want {
t.Fatalf("legacy video estimated amount = %v, want %v", got, want)
}
}
func TestVideoDurationEstimateSumsMultiShotDurations(t *testing.T) {
duration := requestDurationSeconds(map[string]any{
"multi_prompt": []any{
map[string]any{"prompt": "shot 1", "duration": 3},
map[string]any{"prompt": "shot 2", "duration": 7},
},
})
if duration != 10 {
t.Fatalf("multi-shot duration = %v, want 10", duration)
}
}
func firstBillingLine(t *testing.T, items []any) map[string]any {
t.Helper()
if len(items) != 1 {
t.Fatalf("items length = %d, want 1: %+v", len(items), items)
}
line, ok := items[0].(map[string]any)
if !ok {
t.Fatalf("item type = %T, want map[string]any", items[0])
}
return line
}

View File

@ -86,7 +86,7 @@ func taskMetrics(task store.GatewayTask, user *auth.User, body map[string]any, c
copyIfPresent(metrics, body, "style")
case "videos.generations":
metrics["hasReferenceImage"] = imageInputCount(body) > 0
metrics["hasReferenceVideo"] = hasAnyString(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo")
metrics["hasReferenceVideo"] = hasAnyString(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") || hasVideoContent(body)
copyIfPresent(metrics, body, "duration")
copyIfPresent(metrics, body, "resolution")
copyIfPresent(metrics, body, "size")
@ -303,9 +303,23 @@ func imageInputCount(body map[string]any) int {
count += len(values)
}
}
for _, item := range contentItems(body["content"]) {
if isImageContent(item) {
count++
}
}
return count
}
func hasVideoContent(body map[string]any) bool {
for _, item := range contentItems(body["content"]) {
if isVideoContent(item) {
return true
}
}
return false
}
func hasAnyString(body map[string]any, keys ...string) bool {
for _, key := range keys {
if stringFromMap(body, key) != "" {

View File

@ -104,6 +104,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
firstCandidateBody = preprocessing.Body
firstPreprocessing = preprocessing.Log
normalizedModelType = candidates[0].ModelType
if preprocessing.Err != nil {
clientErr := parameterPreprocessClientError(preprocessing.Err)
if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil {
return Result{}, logErr
}
failed, finishErr := s.failTask(ctx, task.ID, clients.ErrorCode(clientErr), clientErr.Error(), task.RunMode == "simulation", clientErr, parameterPreprocessingMetrics(firstPreprocessing))
if finishErr != nil {
return Result{}, finishErr
}
return Result{Task: failed, Output: failed.Result}, clientErr
}
if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, firstCandidateBody); err != nil {
return Result{}, err
}
@ -149,6 +160,10 @@ candidatesLoop:
preprocessing := preprocessRequestWithLog(task.Kind, body, candidate)
preprocessingLog := preprocessing.Log
lastPreprocessing = &preprocessingLog
if preprocessing.Err != nil {
lastErr = parameterPreprocessClientError(preprocessing.Err)
break candidatesLoop
}
candidateBody := preprocessing.Body
response, err := s.runCandidate(ctx, task, user, candidateBody, preprocessing.Log, candidate, nextAttemptNo, onDelta)
if err == nil {
@ -481,7 +496,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated)
return clients.Response{}, err
}
uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
uploadedResult, err := s.uploadGeneratedAssets(ctx, task.ID, task.Kind, response.Result)
if err != nil {
metrics := mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), parameterPreprocessingMetrics(preprocessing), map[string]any{
"error": err.Error(),
@ -531,6 +546,9 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
}
func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID string, attemptID string, attemptNo int, candidate store.RuntimeModelCandidate, log parameterPreprocessingLog) error {
if skipTaskParameterPreprocessingLog(log.ModelType) {
return nil
}
_, err := s.store.CreateTaskParamPreprocessingLog(ctx, store.CreateTaskParamPreprocessingLogInput{
TaskID: taskID,
AttemptID: attemptID,
@ -549,6 +567,15 @@ func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID s
return err
}
func skipTaskParameterPreprocessingLog(modelType string) bool {
switch strings.TrimSpace(modelType) {
case "text_generate", "chat", "responses", "text":
return true
default:
return false
}
}
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
if simulated {
return s.clients["simulation"]
@ -687,7 +714,7 @@ func requestedModelTypeFromBody(body map[string]any) string {
func isKnownModelType(value string) bool {
switch value {
case "text_generate", "image_generate", "image_edit", "video_generate", "image_to_video", "text_to_video", "video_edit", "omni_video", "omni":
case "text_generate", "image_generate", "image_edit", "video_generate", "image_to_video", "text_to_video", "video_edit", "video_reference", "video_first_last_frame", "omni_video", "omni":
return true
default:
return false
@ -706,6 +733,11 @@ func videoRequestHasReferenceImage(body map[string]any) bool {
return true
}
}
for _, item := range contentItems(body["content"]) {
if isImageContent(item) {
return true
}
}
return false
}
@ -851,3 +883,15 @@ func validateRequest(kind string, body map[string]any) error {
}
return nil
}
func parameterPreprocessClientError(err error) *clients.ClientError {
if err == nil {
return nil
}
return &clients.ClientError{
Code: "invalid_parameter",
Message: err.Error(),
StatusCode: 400,
Retryable: false,
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,279 @@
package runner
import (
"bytes"
"context"
"encoding/base64"
"os"
"path/filepath"
"strings"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func TestGeneratedAssetDecisionSkipsURLResultAndStripsInlinePayload(t *testing.T) {
item := map[string]any{
"b64_json": base64.StdEncoding.EncodeToString([]byte("inline image")),
"url": "https://cdn.example.com/generated.png",
}
decision, err := generatedAssetDecisionForItem("images.generations", item, defaultGeneratedAssetUploadPolicy())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decision.Inline != nil {
t.Fatalf("URL media should not be uploaded by the default policy")
}
if !containsString(decision.StripKeys, "b64_json") {
t.Fatalf("inline payload should be stripped when URL is already available: %+v", decision.StripKeys)
}
}
func TestGeneratedAssetDecisionUploadsInlineImageBase64(t *testing.T) {
item := map[string]any{
"b64_json": base64.StdEncoding.EncodeToString([]byte("inline image")),
"mime_type": "image/jpeg",
}
decision, err := generatedAssetDecisionForItem("images.generations", item, defaultGeneratedAssetUploadPolicy())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decision.Inline == nil {
t.Fatalf("expected inline image to be uploaded")
}
if decision.Inline.Kind != "image" || decision.Inline.ContentType != "image/jpeg" {
t.Fatalf("unexpected inline image metadata: %+v", decision.Inline)
}
if !containsString(decision.StripKeys, "b64_json") {
t.Fatalf("uploaded inline payload should be stripped: %+v", decision.StripKeys)
}
}
func TestGeneratedAssetDecisionUploadsInlineVideoBuffer(t *testing.T) {
item := map[string]any{
"type": "video",
"video_buffer": []any{float64(0), float64(1), float64(2), float64(3)},
}
decision, err := generatedAssetDecisionForItem("videos.generations", item, defaultGeneratedAssetUploadPolicy())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decision.Inline == nil {
t.Fatalf("expected inline video buffer to be uploaded")
}
if decision.Inline.Kind != "video" || decision.Inline.ContentType != "video/mp4" {
t.Fatalf("unexpected inline video metadata: %+v", decision.Inline)
}
if !containsString(decision.StripKeys, "video_buffer") {
t.Fatalf("uploaded video buffer should be stripped: %+v", decision.StripKeys)
}
}
func TestGeneratedAssetDecisionUploadsDataURL(t *testing.T) {
item := map[string]any{
"url": "data:image/webp;base64," + base64.StdEncoding.EncodeToString([]byte("inline webp")),
}
decision, err := generatedAssetDecisionForItem("images.generations", item, defaultGeneratedAssetUploadPolicy())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decision.Inline == nil {
t.Fatalf("expected data URL to be uploaded")
}
if decision.Inline.SourceKey != "url" || decision.Inline.ContentType != "image/webp" {
t.Fatalf("unexpected data URL metadata: %+v", decision.Inline)
}
if !containsString(decision.StripKeys, "url") {
t.Fatalf("uploaded data URL field should be stripped: %+v", decision.StripKeys)
}
}
func TestGeneratedAssetDecisionUploadsURLWhenPolicyUploadAll(t *testing.T) {
item := map[string]any{
"type": "video",
"video_url": "https://cdn.example.com/generated.mp4",
}
decision, err := generatedAssetDecisionForItem("videos.generations", item, generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: true})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decision.URL == nil {
t.Fatalf("expected URL media to be uploaded")
}
if decision.URL.Kind != "video" || decision.URL.SourceKey != "video_url" {
t.Fatalf("unexpected URL media metadata: %+v", decision.URL)
}
if !containsString(decision.StripKeys, "video_url") {
t.Fatalf("uploaded URL field should be stripped: %+v", decision.StripKeys)
}
}
func TestGeneratedAssetDecisionStoresInlineLocallyWhenPolicyUploadNone(t *testing.T) {
item := map[string]any{
"b64_json": base64.StdEncoding.EncodeToString([]byte("inline image")),
}
decision, err := generatedAssetDecisionForItem("images.generations", item, generatedAssetUploadPolicyFromName(store.FileStorageResultUploadPolicyUploadNone))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decision.Inline == nil || decision.URL != nil {
t.Fatalf("upload_none should still turn inline payloads into static URLs: %+v", decision)
}
if !containsString(decision.StripKeys, "b64_json") {
t.Fatalf("inline payload should be stripped before persistence: %+v", decision.StripKeys)
}
}
func TestGeneratedAssetUploadPolicyFromName(t *testing.T) {
tests := []struct {
name string
policyName string
want generatedAssetUploadPolicy
}{
{
name: "default",
policyName: store.FileStorageResultUploadPolicyDefault,
want: generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: false, StoreInlineMediaLocally: false},
},
{
name: "upload all",
policyName: store.FileStorageResultUploadPolicyUploadAll,
want: generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: true, StoreInlineMediaLocally: false},
},
{
name: "upload none",
policyName: store.FileStorageResultUploadPolicyUploadNone,
want: generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: false, StoreInlineMediaLocally: true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := generatedAssetUploadPolicyFromName(tt.policyName)
if got != tt.want {
t.Fatalf("unexpected policy: got %+v, want %+v", got, tt.want)
}
})
}
}
func TestResolvedGeneratedAssetContentTypePrefersDetectedMedia(t *testing.T) {
pngPayload := []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a, 0, 0, 0, 0}
contentType := resolvedGeneratedAssetContentType("image/jpeg", "image", pngPayload)
if contentType != "image/png" {
t.Fatalf("expected detected PNG content type, got %s", contentType)
}
if extension := fileExtensionForContentType(contentType, "image"); extension != ".png" {
t.Fatalf("expected PNG extension, got %s", extension)
}
}
func TestResolvedGeneratedAssetContentTypeKeepsDeclaredMediaWhenDetectionIsGeneric(t *testing.T) {
contentType := resolvedGeneratedAssetContentType("image/webp", "image", []byte("not enough media bytes"))
if contentType != "image/webp" {
t.Fatalf("expected declared webp content type, got %s", contentType)
}
}
func TestGeneratedAssetFileNameIsUniqueAndTyped(t *testing.T) {
first := generatedAssetFileName("663e19cd4fa9d8078385c7c9", 0, "image/png", "image")
second := generatedAssetFileName("663e19cd4fa9d8078385c7c9", 0, "image/png", "image")
if first == second {
t.Fatalf("expected generated file names to be unique, both were %s", first)
}
if !strings.HasPrefix(first, "gateway-result-663e19cd4fa9d8078385c7c9-01-") || !strings.HasSuffix(first, ".png") {
t.Fatalf("unexpected generated file name: %s", first)
}
}
func TestUploadGeneratedAssetStoresLocalWhenNoChannels(t *testing.T) {
storageDir := t.TempDir()
service := &Service{cfg: config.Config{LocalGeneratedStorageDir: storageDir}}
payload := []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a, 0, 0, 0, 0}
asset := &generatedInlineAsset{
Bytes: payload,
ContentType: "image/jpeg",
Kind: "image",
SourceKey: "b64_json",
}
upload, contentType, kind, strategy, err := service.uploadGeneratedAsset(context.Background(), "task-123", asset, 0, nil, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if contentType != "image/png" || kind != "image" || strategy != "local_static_inline_media" {
t.Fatalf("unexpected local upload metadata: contentType=%s kind=%s strategy=%s", contentType, kind, strategy)
}
urlValue := stringFromAny(upload["url"])
if !strings.HasPrefix(urlValue, "/static/generated/gateway-result-task-123-01-") || !strings.HasSuffix(urlValue, ".png") {
t.Fatalf("unexpected local static URL: %s", urlValue)
}
entries, err := os.ReadDir(storageDir)
if err != nil {
t.Fatalf("failed to read local static dir: %v", err)
}
if len(entries) != 1 || !strings.HasSuffix(entries[0].Name(), ".png") {
t.Fatalf("expected one PNG file in local static dir, got %+v", entries)
}
stored, err := os.ReadFile(filepath.Join(storageDir, entries[0].Name()))
if err != nil {
t.Fatalf("failed to read local static file: %v", err)
}
if !bytes.Equal(stored, payload) {
t.Fatalf("stored payload does not match source payload")
}
}
func TestUploadFileStoresLocalWhenNoChannels(t *testing.T) {
storageDir := t.TempDir()
service := &Service{cfg: config.Config{
LocalUploadedStorageDir: storageDir,
ServerMainBaseURL: "http://127.0.0.1:1",
ServerMainInternalToken: "change-me",
}}
payload := []byte("%PDF-1.4")
upload, err := service.UploadFile(context.Background(), FileUploadPayload{
Bytes: payload,
ContentType: "application/pdf",
FileName: "用户文件.png",
Source: "playground",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
urlValue := stringFromAny(upload["url"])
if !strings.HasPrefix(urlValue, "/static/uploaded/") || !strings.HasSuffix(urlValue, ".pdf") {
t.Fatalf("unexpected uploaded local static URL: %s", urlValue)
}
storageChannel, _ := upload["storageChannel"].(map[string]any)
if stringFromAny(storageChannel["provider"]) != "local_static" {
t.Fatalf("expected local static provider metadata, got %+v", upload["storageChannel"])
}
assetStorage, _ := upload["assetStorage"].(map[string]any)
if stringFromAny(assetStorage["strategy"]) != "local_static_upload" || stringFromAny(assetStorage["scene"]) != store.FileStorageSceneUpload {
t.Fatalf("unexpected upload asset storage metadata: %+v", assetStorage)
}
entries, err := os.ReadDir(storageDir)
if err != nil {
t.Fatalf("failed to read uploaded static dir: %v", err)
}
if len(entries) != 1 || !strings.HasSuffix(entries[0].Name(), ".pdf") {
t.Fatalf("expected one PDF file in uploaded static dir, got %+v", entries)
}
stored, err := os.ReadFile(filepath.Join(storageDir, entries[0].Name()))
if err != nil {
t.Fatalf("failed to read uploaded static file: %v", err)
}
if !bytes.Equal(stored, payload) {
t.Fatalf("stored uploaded payload does not match source payload")
}
}

View File

@ -0,0 +1,499 @@
package store
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/jackc/pgx/v5"
)
const defaultServerMainUploadURL = "http://127.0.0.1:3001/v1/files/upload"
const (
FileStorageSceneUpload = "upload"
FileStorageSceneImageResult = "image_result"
)
const (
FileStorageResultUploadPolicyDefault = "default"
FileStorageResultUploadPolicyUploadAll = "upload_all"
FileStorageResultUploadPolicyUploadNone = "upload_none"
)
const SystemSettingFileStorage = "file_storage"
const fileStorageChannelColumns = `
id::text, channel_key, name, provider, COALESCE(upload_url, ''), credentials,
config, retry_policy, priority, status, COALESCE(last_error, ''),
COALESCE(last_failed_at::text, ''), COALESCE(last_succeeded_at::text, ''),
created_at, updated_at`
type FileStorageChannel struct {
ID string `json:"id"`
ChannelKey string `json:"channelKey"`
Name string `json:"name"`
Provider string `json:"provider"`
UploadURL string `json:"uploadUrl,omitempty"`
APIKey string `json:"-"`
CredentialsPreview map[string]any `json:"credentialsPreview,omitempty"`
Scenes []string `json:"scenes,omitempty"`
Config map[string]any `json:"config,omitempty"`
RetryPolicy map[string]any `json:"retryPolicy,omitempty"`
Priority int `json:"priority"`
Status string `json:"status"`
LastError string `json:"lastError,omitempty"`
LastFailedAt string `json:"lastFailedAt,omitempty"`
LastSucceededAt string `json:"lastSucceededAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type FileStorageChannelInput struct {
ChannelKey string `json:"channelKey"`
Name string `json:"name"`
Provider string `json:"provider"`
UploadURL string `json:"uploadUrl"`
APIKey *string `json:"apiKey"`
Scenes []string `json:"scenes"`
Config map[string]any `json:"config"`
RetryPolicy map[string]any `json:"retryPolicy"`
Priority int `json:"priority"`
Status string `json:"status"`
}
type FileStorageSettings struct {
ResultUploadPolicy string `json:"resultUploadPolicy"`
}
type FileStorageSettingsInput struct {
ResultUploadPolicy string `json:"resultUploadPolicy"`
}
type fileStorageChannelScanner interface {
Scan(dest ...any) error
}
func (s *Store) ListFileStorageChannels(ctx context.Context) ([]FileStorageChannel, error) {
rows, err := s.pool.Query(ctx, `
SELECT `+fileStorageChannelColumns+`
FROM file_storage_channels
WHERE deleted_at IS NULL
ORDER BY priority ASC, created_at ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]FileStorageChannel, 0)
for rows.Next() {
item, err := scanFileStorageChannel(rows)
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
}
func (s *Store) ListEnabledFileStorageChannels(ctx context.Context) ([]FileStorageChannel, error) {
return s.listEnabledFileStorageChannels(ctx, "")
}
func (s *Store) ListEnabledFileStorageChannelsForScene(ctx context.Context, scene string) ([]FileStorageChannel, error) {
return s.listEnabledFileStorageChannels(ctx, normalizeFileStorageScene(scene))
}
func (s *Store) listEnabledFileStorageChannels(ctx context.Context, scene string) ([]FileStorageChannel, error) {
rows, err := s.pool.Query(ctx, `
SELECT `+fileStorageChannelColumns+`
FROM file_storage_channels
WHERE deleted_at IS NULL
AND status = 'enabled'
AND (
$1 = ''
OR NOT (config ? 'scenes')
OR jsonb_typeof(config->'scenes') <> 'array'
OR (config->'scenes') ? $1
)
ORDER BY priority ASC, created_at ASC`, scene)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]FileStorageChannel, 0)
for rows.Next() {
item, err := scanFileStorageChannel(rows)
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, rows.Err()
}
func (s *Store) GetFileStorageChannel(ctx context.Context, id string) (FileStorageChannel, error) {
return scanFileStorageChannel(s.pool.QueryRow(ctx, `
SELECT `+fileStorageChannelColumns+`
FROM file_storage_channels
WHERE id = $1::uuid
AND deleted_at IS NULL`, id))
}
func (s *Store) CreateFileStorageChannel(ctx context.Context, input FileStorageChannelInput) (FileStorageChannel, error) {
input = normalizeFileStorageChannelInput(input)
credentials, _ := json.Marshal(credentialsFromFileStorageInput(input))
config, _ := json.Marshal(configFromFileStorageInput(input))
retryPolicy, _ := json.Marshal(defaultFileStorageRetryPolicyIfEmpty(input.RetryPolicy))
return scanFileStorageChannel(s.pool.QueryRow(ctx, `
INSERT INTO file_storage_channels (
channel_key, name, provider, upload_url, credentials, config, retry_policy, priority, status
)
VALUES ($1, $2, $3, NULLIF($4, ''), $5, $6, $7, $8, $9)
RETURNING `+fileStorageChannelColumns,
input.ChannelKey,
input.Name,
input.Provider,
input.UploadURL,
credentials,
config,
retryPolicy,
input.Priority,
input.Status,
))
}
func (s *Store) UpdateFileStorageChannel(ctx context.Context, id string, input FileStorageChannelInput) (FileStorageChannel, error) {
input = normalizeFileStorageChannelInput(input)
replaceCredentials := input.APIKey != nil
credentials, _ := json.Marshal(credentialsFromFileStorageInput(input))
config, _ := json.Marshal(configFromFileStorageInput(input))
retryPolicy, _ := json.Marshal(defaultFileStorageRetryPolicyIfEmpty(input.RetryPolicy))
return scanFileStorageChannel(s.pool.QueryRow(ctx, `
UPDATE file_storage_channels
SET channel_key = $2,
name = $3,
provider = $4,
upload_url = NULLIF($5, ''),
credentials = CASE WHEN $6::boolean THEN $7 ELSE credentials END,
config = $8,
retry_policy = $9,
priority = $10,
status = $11,
updated_at = now()
WHERE id = $1::uuid
AND deleted_at IS NULL
RETURNING `+fileStorageChannelColumns,
id,
input.ChannelKey,
input.Name,
input.Provider,
input.UploadURL,
replaceCredentials,
credentials,
config,
retryPolicy,
input.Priority,
input.Status,
))
}
func (s *Store) DeleteFileStorageChannel(ctx context.Context, id string) error {
result, err := s.pool.Exec(ctx, `
UPDATE file_storage_channels
SET deleted_at = now(),
status = 'disabled',
updated_at = now()
WHERE id = $1::uuid
AND deleted_at IS NULL`, id)
if err != nil {
return err
}
if result.RowsAffected() == 0 {
return pgx.ErrNoRows
}
return nil
}
func (s *Store) MarkFileStorageChannelFailure(ctx context.Context, id string, message string) error {
if strings.TrimSpace(id) == "" {
return nil
}
_, err := s.pool.Exec(ctx, `
UPDATE file_storage_channels
SET last_error = NULLIF($2, ''),
last_failed_at = now(),
updated_at = now()
WHERE id = $1::uuid
AND deleted_at IS NULL`, id, strings.TrimSpace(message))
return err
}
func (s *Store) MarkFileStorageChannelSuccess(ctx context.Context, id string) error {
if strings.TrimSpace(id) == "" {
return nil
}
_, err := s.pool.Exec(ctx, `
UPDATE file_storage_channels
SET last_error = NULL,
last_succeeded_at = now(),
updated_at = now()
WHERE id = $1::uuid
AND deleted_at IS NULL`, id)
return err
}
func scanFileStorageChannel(scanner fileStorageChannelScanner) (FileStorageChannel, error) {
var item FileStorageChannel
var credentials []byte
var config []byte
var retryPolicy []byte
if err := scanner.Scan(
&item.ID,
&item.ChannelKey,
&item.Name,
&item.Provider,
&item.UploadURL,
&credentials,
&config,
&retryPolicy,
&item.Priority,
&item.Status,
&item.LastError,
&item.LastFailedAt,
&item.LastSucceededAt,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return FileStorageChannel{}, err
}
credentialObject := decodeObject(credentials)
item.APIKey = stringFromObject(credentialObject, "apiKey")
item.CredentialsPreview = maskCredentialsPreview(credentials)
configObject := decodeObject(config)
item.Scenes = fileStorageScenesFromConfig(configObject)
item.Config = fileStorageConfigWithoutManagedFields(configObject)
item.RetryPolicy = decodeObject(retryPolicy)
return item, nil
}
func normalizeFileStorageChannelInput(input FileStorageChannelInput) FileStorageChannelInput {
input.ChannelKey = strings.TrimSpace(input.ChannelKey)
input.Name = strings.TrimSpace(input.Name)
input.Provider = strings.ToLower(strings.TrimSpace(input.Provider))
input.UploadURL = strings.TrimSpace(input.UploadURL)
if input.APIKey != nil {
apiKey := strings.TrimSpace(*input.APIKey)
input.APIKey = &apiKey
}
input.Scenes = normalizeFileStorageScenes(input.Scenes)
input.Status = strings.ToLower(strings.TrimSpace(input.Status))
if input.Provider == "" {
input.Provider = "server_main_openapi"
}
if input.Provider == "server_main_openapi" && input.UploadURL == "" {
input.UploadURL = defaultServerMainUploadURL
}
if input.Status == "" {
input.Status = "disabled"
}
if input.Priority <= 0 {
input.Priority = 100
}
return input
}
func credentialsFromFileStorageInput(input FileStorageChannelInput) map[string]any {
apiKey := fileStorageInputAPIKey(input)
if apiKey == "" {
return map[string]any{}
}
return map[string]any{"apiKey": apiKey}
}
func fileStorageInputAPIKey(input FileStorageChannelInput) string {
if input.APIKey == nil {
return ""
}
return strings.TrimSpace(*input.APIKey)
}
func configFromFileStorageInput(input FileStorageChannelInput) map[string]any {
config := map[string]any{}
for key, value := range emptyObjectIfNil(input.Config) {
config[key] = value
}
config["scenes"] = normalizeFileStorageScenes(input.Scenes)
return config
}
func fileStorageConfigWithoutManagedFields(config map[string]any) map[string]any {
out := map[string]any{}
for key, value := range config {
if key == "scenes" || key == "resultUploadPolicy" {
continue
}
out[key] = value
}
if len(out) == 0 {
return nil
}
return out
}
func DefaultFileStorageSettings() FileStorageSettings {
return FileStorageSettings{ResultUploadPolicy: FileStorageResultUploadPolicyDefault}
}
func (s *Store) GetFileStorageSettings(ctx context.Context) (FileStorageSettings, error) {
var value []byte
err := s.pool.QueryRow(ctx, `
SELECT value
FROM system_settings
WHERE setting_key = $1`, SystemSettingFileStorage).Scan(&value)
if err != nil {
if IsNotFound(err) {
return DefaultFileStorageSettings(), nil
}
return FileStorageSettings{}, err
}
return fileStorageSettingsFromValue(decodeObject(value)), nil
}
func (s *Store) UpdateFileStorageSettings(ctx context.Context, input FileStorageSettingsInput) (FileStorageSettings, error) {
settings := FileStorageSettings{ResultUploadPolicy: NormalizeFileStorageResultUploadPolicy(input.ResultUploadPolicy)}
value, _ := json.Marshal(settings)
var saved []byte
err := s.upsertFileStorageSettings(ctx, value, &saved)
if err != nil && IsUndefinedDatabaseObject(err) {
if ensureErr := s.ensureSystemSettingsTable(ctx); ensureErr != nil {
return FileStorageSettings{}, ensureErr
}
err = s.upsertFileStorageSettings(ctx, value, &saved)
}
if err != nil {
return FileStorageSettings{}, err
}
return fileStorageSettingsFromValue(decodeObject(saved)), nil
}
func (s *Store) upsertFileStorageSettings(ctx context.Context, value []byte, saved *[]byte) error {
return s.pool.QueryRow(ctx, `
INSERT INTO system_settings (setting_key, value)
VALUES ($1, $2)
ON CONFLICT (setting_key)
DO UPDATE SET value = EXCLUDED.value, updated_at = now()
RETURNING value`, SystemSettingFileStorage, value).Scan(saved)
}
func (s *Store) ensureSystemSettingsTable(ctx context.Context) error {
_, err := s.pool.Exec(ctx, `
CREATE TABLE IF NOT EXISTS system_settings (
setting_key text PRIMARY KEY,
value jsonb NOT NULL DEFAULT '{}'::jsonb,
created_at timestamptz NOT NULL DEFAULT now(),
updated_at timestamptz NOT NULL DEFAULT now()
)`)
return err
}
func fileStorageSettingsFromValue(value map[string]any) FileStorageSettings {
settings := DefaultFileStorageSettings()
if value == nil {
return settings
}
settings.ResultUploadPolicy = NormalizeFileStorageResultUploadPolicy(stringFromAny(value["resultUploadPolicy"]))
return settings
}
func NormalizeFileStorageResultUploadPolicy(policy string) string {
normalized := strings.ToLower(strings.TrimSpace(policy))
normalized = strings.ReplaceAll(normalized, "-", "_")
switch normalized {
case "", "default", "non_link_only", "inline_only", "nonlink_only", "non_link":
return FileStorageResultUploadPolicyDefault
case "upload_all", "all", "always", "all_upload":
return FileStorageResultUploadPolicyUploadAll
case "upload_none", "none", "never", "disabled", "no_upload", "skip", "skip_all":
return FileStorageResultUploadPolicyUploadNone
default:
return FileStorageResultUploadPolicyDefault
}
}
func fileStorageScenesFromConfig(config map[string]any) []string {
if config == nil {
return defaultFileStorageScenes()
}
raw, ok := config["scenes"]
if !ok {
return defaultFileStorageScenes()
}
items, ok := raw.([]any)
if !ok {
return defaultFileStorageScenes()
}
scenes := make([]string, 0, len(items))
for _, item := range items {
if value, ok := item.(string); ok {
scenes = append(scenes, value)
}
}
return normalizeFileStorageScenes(scenes)
}
func normalizeFileStorageScenes(scenes []string) []string {
seen := map[string]bool{}
out := make([]string, 0, len(scenes))
for _, item := range scenes {
scene := normalizeFileStorageScene(item)
if scene == "" || seen[scene] {
continue
}
seen[scene] = true
out = append(out, scene)
}
if len(out) == 0 {
return defaultFileStorageScenes()
}
return out
}
func normalizeFileStorageScene(scene string) string {
return strings.ToLower(strings.TrimSpace(scene))
}
func defaultFileStorageScenes() []string {
return []string{FileStorageSceneUpload, FileStorageSceneImageResult}
}
func defaultFileStorageRetryPolicyIfEmpty(policy map[string]any) map[string]any {
if len(policy) > 0 {
return policy
}
return map[string]any{
"enabled": true,
"maxRetries": 3,
"backoffSeconds": []any{60, 120, 180},
"strategy": "exponential",
}
}
func stringFromObject(value map[string]any, key string) string {
if value == nil {
return ""
}
raw, _ := value[key].(string)
return strings.TrimSpace(raw)
}
func stringFromAny(value any) string {
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
default:
return ""
}
}

View File

@ -31,9 +31,18 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID
}
if reservation.Metric == "" || reservation.Amount > reservation.Limit {
return RateLimitResult{}, &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
Retryable: false,
ScopeType: reservation.ScopeType,
ScopeKey: reservation.ScopeKey,
ScopeName: reservation.ScopeName,
ScopeMetadata: reservation.ScopeMetadata,
Metric: reservation.Metric,
Limit: reservation.Limit,
Amount: reservation.Amount,
Projected: reservation.Amount,
WindowSeconds: reservation.WindowSeconds,
Policy: reservation.Policy,
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
Retryable: false,
}
}
if reservation.WindowSeconds <= 0 {
@ -78,10 +87,22 @@ WHERE scope_type = $1
}
if active+reservation.Amount > reservation.Limit {
return "", &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
Retryable: true,
ScopeType: reservation.ScopeType,
ScopeKey: reservation.ScopeKey,
ScopeName: reservation.ScopeName,
ScopeMetadata: reservation.ScopeMetadata,
Metric: reservation.Metric,
Limit: reservation.Limit,
Amount: reservation.Amount,
Current: active,
Used: active,
Projected: active + reservation.Amount,
WindowSeconds: reservation.WindowSeconds,
ResetAt: nextAvailableAt,
Policy: reservation.Policy,
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
Retryable: true,
}
}
var leaseID string
@ -135,11 +156,13 @@ RETURNING window_start`,
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
currentUsed := 0.0
currentReserved := 0.0
_ = tx.QueryRow(ctx, `
WITH bounds AS (
SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
)
SELECT counters.reset_at
SELECT counters.used_value::float8, counters.reserved_value::float8, counters.reset_at
FROM gateway_rate_limit_counters counters
JOIN bounds ON counters.window_start = bounds.window_start
WHERE scope_type = $1
@ -149,12 +172,26 @@ WHERE scope_type = $1
reservation.ScopeKey,
reservation.Metric,
reservation.WindowSeconds,
).Scan(&resetAt)
).Scan(&currentUsed, &currentReserved, &resetAt)
current := currentUsed + currentReserved
return RateLimitReservation{}, &RateLimitExceededError{
Metric: reservation.Metric,
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
RetryAfter: retryAfterUntil(resetAt),
Retryable: true,
ScopeType: reservation.ScopeType,
ScopeKey: reservation.ScopeKey,
ScopeName: reservation.ScopeName,
ScopeMetadata: reservation.ScopeMetadata,
Metric: reservation.Metric,
Limit: reservation.Limit,
Amount: reservation.Amount,
Current: current,
Used: currentUsed,
Reserved: currentReserved,
Projected: current + reservation.Amount,
WindowSeconds: reservation.WindowSeconds,
ResetAt: resetAt,
Policy: reservation.Policy,
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
RetryAfter: retryAfterUntil(resetAt),
Retryable: true,
}
}
return RateLimitReservation{}, err

View File

@ -33,10 +33,23 @@ func ModelCandidateErrorCode(err error) string {
}
type RateLimitExceededError struct {
Metric string
Message string
RetryAfter time.Duration
Retryable bool
ScopeType string
ScopeKey string
ScopeName string
ScopeMetadata map[string]any
Metric string
Limit float64
Amount float64
Current float64
Used float64
Reserved float64
Projected float64
WindowSeconds int
ResetAt time.Time
Policy map[string]any
Message string
RetryAfter time.Duration
Retryable bool
}
func (e *RateLimitExceededError) Error() string {
@ -166,12 +179,15 @@ type RateLimitReservation struct {
ReservationID string
ScopeType string
ScopeKey string
ScopeName string
ScopeMetadata map[string]any
Metric string
Limit float64
Amount float64
WindowSeconds int
LeaseTTLSeconds int
WindowStart time.Time
Policy map[string]any
}
type RateLimitResult struct {

View File

@ -10,6 +10,7 @@ import (
type UserGroupPolicy struct {
ID string
GroupKey string
Name string
RateLimitPolicy map[string]any
BillingDiscountPolicy map[string]any
}
@ -23,12 +24,12 @@ func (s *Store) ResolveUserGroupPolicy(ctx context.Context, user *auth.User) (Us
var rateLimit []byte
var billing []byte
err := s.pool.QueryRow(ctx, `
SELECT id::text, group_key, rate_limit_policy, billing_discount_policy
SELECT id::text, group_key, name, rate_limit_policy, billing_discount_policy
FROM gateway_user_groups
WHERE status = 'active'
AND (($1 <> '' AND id = NULLIF($1, '')::uuid) OR ($1 = '' AND group_key = 'default'))
ORDER BY CASE WHEN id::text = $1 THEN 0 ELSE 1 END, priority ASC
LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &rateLimit, &billing)
LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &item.Name, &rateLimit, &billing)
if err != nil {
if err == pgx.ErrNoRows {
return UserGroupPolicy{}, nil

View File

@ -0,0 +1,85 @@
CREATE TABLE IF NOT EXISTS system_settings (
setting_key text PRIMARY KEY,
value jsonb NOT NULL DEFAULT '{}'::jsonb,
created_at timestamptz NOT NULL DEFAULT now(),
updated_at timestamptz NOT NULL DEFAULT now()
);
INSERT INTO system_settings (setting_key, value)
VALUES (
'file_storage',
'{"resultUploadPolicy": "default"}'::jsonb
)
ON CONFLICT (setting_key) DO NOTHING;
CREATE TABLE IF NOT EXISTS file_storage_channels (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
channel_key text NOT NULL UNIQUE,
name text NOT NULL,
provider text NOT NULL DEFAULT 'server_main_openapi',
upload_url text,
credentials jsonb NOT NULL DEFAULT '{}'::jsonb,
config jsonb NOT NULL DEFAULT '{}'::jsonb,
retry_policy jsonb NOT NULL DEFAULT '{
"enabled": true,
"maxRetries": 3,
"backoffSeconds": [60, 120, 180],
"strategy": "exponential"
}'::jsonb,
priority integer NOT NULL DEFAULT 100,
status text NOT NULL DEFAULT 'disabled',
last_error text,
last_failed_at timestamptz,
last_succeeded_at timestamptz,
created_at timestamptz NOT NULL DEFAULT now(),
updated_at timestamptz NOT NULL DEFAULT now(),
deleted_at timestamptz
);
ALTER TABLE IF EXISTS file_storage_channels
ADD COLUMN IF NOT EXISTS upload_url text,
ADD COLUMN IF NOT EXISTS config jsonb NOT NULL DEFAULT '{}'::jsonb,
ADD COLUMN IF NOT EXISTS retry_policy jsonb NOT NULL DEFAULT '{
"enabled": true,
"maxRetries": 3,
"backoffSeconds": [60, 120, 180],
"strategy": "exponential"
}'::jsonb,
ADD COLUMN IF NOT EXISTS priority integer NOT NULL DEFAULT 100,
ADD COLUMN IF NOT EXISTS last_error text,
ADD COLUMN IF NOT EXISTS last_failed_at timestamptz,
ADD COLUMN IF NOT EXISTS last_succeeded_at timestamptz,
ADD COLUMN IF NOT EXISTS deleted_at timestamptz;
CREATE INDEX IF NOT EXISTS idx_file_storage_channels_active
ON file_storage_channels (status, priority, created_at)
WHERE deleted_at IS NULL;
INSERT INTO file_storage_channels (
channel_key,
name,
provider,
upload_url,
credentials,
config,
retry_policy,
priority,
status
)
VALUES (
'server-main-openapi',
'server-main OpenAPI',
'server_main_openapi',
'http://127.0.0.1:3001/v1/files/upload',
'{}'::jsonb,
'{"scenes": ["upload", "image_result"]}'::jsonb,
'{
"enabled": true,
"maxRetries": 3,
"backoffSeconds": [60, 120, 180],
"strategy": "exponential"
}'::jsonb,
100,
'disabled'
)
ON CONFLICT (channel_key) DO NOTHING;

View File

@ -0,0 +1,13 @@
CREATE TABLE IF NOT EXISTS system_settings (
setting_key text PRIMARY KEY,
value jsonb NOT NULL DEFAULT '{}'::jsonb,
created_at timestamptz NOT NULL DEFAULT now(),
updated_at timestamptz NOT NULL DEFAULT now()
);
INSERT INTO system_settings (setting_key, value)
VALUES (
'file_storage',
'{"resultUploadPolicy": "default"}'::jsonb
)
ON CONFLICT (setting_key) DO NOTHING;

View File

@ -2,6 +2,10 @@ import { useEffect, useMemo, useRef, useState, type FormEvent } from 'react';
import type {
BaseModelCatalogItem,
CatalogProvider,
FileStorageChannel,
FileStorageSettings,
FileStorageSettingsUpdateRequest,
FileStorageChannelUpsertRequest,
GatewayAccessRuleBatchRequest,
GatewayAccessRule,
GatewayAccessRuleUpsertRequest,
@ -34,17 +38,22 @@ import {
batchApiKeyAccessRules,
createAccessRule,
createApiKey,
createFileStorageChannel,
createGatewayUser,
createPlatform,
createTenant,
createUserGroup,
deleteAccessRule,
deleteApiKey,
deleteFileStorageChannel,
deleteGatewayUser,
deletePlatform,
deleteTenant,
deleteUserGroup,
GatewayApiError,
getHealth,
listFileStorageChannels,
getFileStorageSettings,
getNetworkProxyConfig,
getRunnerPolicy,
getWalletSummary,
@ -78,6 +87,8 @@ import {
setUserWalletBalance,
type HealthResponse,
updateAccessRule,
updateFileStorageChannel,
updateFileStorageSettings,
updateGatewayUser,
updatePlatform,
updatePlatformDynamicPriority,
@ -135,6 +146,8 @@ type DataKey =
| 'playgroundModels'
| 'modelCatalog'
| 'networkProxyConfig'
| 'fileStorageChannels'
| 'fileStorageSettings'
| 'platforms'
| 'models'
| 'providers'
@ -179,6 +192,8 @@ export function App() {
});
const [playgroundModels, setPlaygroundModels] = useState<PlatformModel[]>([]);
const [networkProxyConfig, setNetworkProxyConfig] = useState<GatewayNetworkProxyConfig | null>(null);
const [fileStorageChannels, setFileStorageChannels] = useState<FileStorageChannel[]>([]);
const [fileStorageSettings, setFileStorageSettings] = useState<FileStorageSettings | null>(null);
const [providers, setProviders] = useState<CatalogProvider[]>([]);
const [baseModels, setBaseModels] = useState<BaseModelCatalogItem[]>([]);
const [pricingRules, setPricingRules] = useState<PricingRule[]>([]);
@ -257,7 +272,8 @@ export function App() {
loadedDataKeysRef.current.add('modelRateLimits');
loadedDataKeysRef.current.add('platforms');
})
.catch(() => {
.catch((err) => {
if (handleAuthExpired(err, token)) return;
loadedDataKeysRef.current.delete('modelRateLimits');
loadedDataKeysRef.current.delete('platforms');
});
@ -296,6 +312,8 @@ export function App() {
auditLogs,
apiKeys,
baseModels,
fileStorageChannels,
fileStorageSettings,
modelCatalog,
models,
networkProxyConfig,
@ -315,7 +333,7 @@ export function App() {
users,
walletAccounts,
walletTransactions,
}), [accessRules, apiKeys, auditLogs, baseModels, modelCatalog, modelRateLimits, modelRateLimitsUpdatedAt, models, networkProxyConfig, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runnerPolicy, runtimePolicySets, taskResult, tasks, tenants, userGroups, users, walletAccounts, walletTransactions]);
}), [accessRules, apiKeys, auditLogs, baseModels, fileStorageChannels, fileStorageSettings, modelCatalog, modelRateLimits, modelRateLimitsUpdatedAt, models, networkProxyConfig, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runnerPolicy, runtimePolicySets, taskResult, tasks, tenants, userGroups, users, walletAccounts, walletTransactions]);
async function refresh(nextToken = token) {
await ensureRouteData(nextToken, true);
@ -354,6 +372,7 @@ export function App() {
requestKeys.forEach((key) => loadedDataKeysRef.current.add(key));
setState('ready');
} catch (err) {
if (handleAuthExpired(err, nextToken)) return;
setState('error');
setError(err instanceof Error ? err.message : '加载失败');
} finally {
@ -388,6 +407,12 @@ export function App() {
case 'networkProxyConfig':
setNetworkProxyConfig(await getNetworkProxyConfig(nextToken));
return;
case 'fileStorageChannels':
setFileStorageChannels((await listFileStorageChannels(nextToken)).items);
return;
case 'fileStorageSettings':
setFileStorageSettings(await getFileStorageSettings(nextToken));
return;
case 'playgroundModels':
setPlaygroundModels((await listPlayableModels(nextToken)).items);
return;
@ -818,6 +843,53 @@ export function App() {
}
}
async function saveFileStorageChannel(input: FileStorageChannelUpsertRequest, channelId?: string) {
setCoreState('loading');
setCoreMessage('');
try {
const item = channelId
? await updateFileStorageChannel(token, channelId, input)
: await createFileStorageChannel(token, input);
setFileStorageChannels((current) => [item, ...current.filter((channel) => channel.id !== item.id)]);
setCoreState('ready');
setCoreMessage(channelId ? '文件存储渠道已更新。' : '文件存储渠道已新增。');
} catch (err) {
setCoreState('error');
setCoreMessage(err instanceof Error ? err.message : channelId ? '更新文件存储渠道失败' : '新增文件存储渠道失败');
throw err;
}
}
async function saveFileStorageSettings(input: FileStorageSettingsUpdateRequest) {
setCoreState('loading');
setCoreMessage('');
try {
const settings = await updateFileStorageSettings(token, input);
setFileStorageSettings(settings);
setCoreState('ready');
setCoreMessage('文件存储全局策略已更新。');
} catch (err) {
setCoreState('error');
setCoreMessage(err instanceof Error ? err.message : '更新文件存储全局策略失败');
throw err;
}
}
async function removeFileStorageChannel(channelId: string) {
setCoreState('loading');
setCoreMessage('');
try {
await deleteFileStorageChannel(token, channelId);
setFileStorageChannels((current) => current.filter((channel) => channel.id !== channelId));
setCoreState('ready');
setCoreMessage('文件存储渠道已删除。');
} catch (err) {
setCoreState('error');
setCoreMessage(err instanceof Error ? err.message : '删除文件存储渠道失败');
throw err;
}
}
async function batchSaveAPIKeyAccessRules(input: GatewayAccessRuleBatchRequest) {
setCoreState('loading');
setCoreMessage('');
@ -856,7 +928,7 @@ export function App() {
}
}
function signOut() {
function resetAuthenticatedSession() {
persistAccessToken('');
setToken('');
loadedDataKeysRef.current = new Set(health ? ['health'] : []);
@ -867,6 +939,7 @@ export function App() {
setModelCatalog({ items: [], filters: { capabilities: [], providers: [] }, summary: { modelCount: 0, sourceCount: 0 } });
setPlaygroundModels([]);
setNetworkProxyConfig(null);
setFileStorageChannels([]);
setProviders([]);
setBaseModels([]);
setPricingRules([]);
@ -892,6 +965,19 @@ export function App() {
setWalletTransactionTotal(0);
setWorkspaceTransactionQuery(defaultWorkspaceTransactionQuery());
setCoreMessage('');
}
function handleAuthExpired(err: unknown, failedToken: string) {
if (!failedToken || !(err instanceof GatewayApiError) || err.details.status !== 401) return false;
resetAuthenticatedSession();
setError('');
setAuthMode('login');
navigatePath(pathForWorkspaceSection('overview'));
return true;
}
function signOut() {
resetAuthenticatedSession();
navigatePath('/');
}
@ -1039,6 +1125,7 @@ export function App() {
onDeletePricingRuleSet={removePricingRuleSet}
onDeleteRuntimePolicySet={removeRuntimePolicySet}
onDeleteAccessRule={removeAccessRule}
onDeleteFileStorageChannel={removeFileStorageChannel}
onDeleteTenant={removeTenant}
onDeleteUser={removeUser}
onDeleteUserGroup={removeUserGroup}
@ -1054,6 +1141,8 @@ export function App() {
onSaveRuntimePolicySet={saveRuntimePolicySet}
onBatchAccessRules={batchSaveAccessRules}
onSaveAccessRule={saveAccessRule}
onSaveFileStorageChannel={saveFileStorageChannel}
onSaveFileStorageSettings={saveFileStorageSettings}
onSaveTenant={saveTenant}
onSaveUser={saveUser}
onSetUserWalletBalance={saveUserWalletBalance}
@ -1267,6 +1356,8 @@ function dataKeysForRoute(
return ['auditLogs'];
case 'accessRules':
return ['accessRules', 'userGroups', 'platforms', 'models'];
case 'systemSettings':
return ['fileStorageSettings', 'fileStorageChannels'];
default:
return [];
}

View File

@ -5,6 +5,10 @@ import type {
CatalogProvider,
CatalogProviderUpsertRequest,
CreatedGatewayApiKey,
FileStorageChannel,
FileStorageSettings,
FileStorageSettingsUpdateRequest,
FileStorageChannelUpsertRequest,
GatewayAccessRuleBatchRequest,
GatewayAccessRule,
GatewayAccessRuleUpsertRequest,
@ -15,6 +19,7 @@ import type {
GatewayTenant,
GatewayTenantUpsertRequest,
GatewayNetworkProxyConfig,
GatewayPricingEstimate,
GatewayTask,
GatewayTaskParamPreprocessingLog,
GatewayUser,
@ -601,10 +606,17 @@ export async function createImageGenerationTask(
model: string;
prompt: string;
aspect_ratio?: string;
content?: Array<Record<string, unknown>>;
count?: number;
height?: number;
image?: string | string[];
image_url?: string | string[];
image_urls?: string[];
images?: string[];
n?: number;
quality?: string;
referenceImage?: string | string[];
reference_image?: string | string[];
resolution?: string;
runMode?: string;
simulation?: boolean;
@ -622,7 +634,26 @@ export async function createImageGenerationTask(
export async function createImageEditTask(
token: string,
input: { model: string; prompt: string; image?: string; mask?: string; runMode?: string; simulation?: boolean },
input: {
model: string;
prompt: string;
aspect_ratio?: string;
content?: Array<Record<string, unknown>>;
count?: number;
height?: number;
image?: string | string[];
image_url?: string | string[];
image_urls?: string[];
images?: string[];
mask?: string;
n?: number;
quality?: string;
resolution?: string;
runMode?: string;
simulation?: boolean;
size?: string;
width?: number;
},
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/images/edits', {
body: input,
@ -632,25 +663,83 @@ export async function createImageEditTask(
});
}
export type VideoGenerationContentRole =
| 'first_frame'
| 'last_frame'
| 'reference_image'
| 'reference_video'
| 'reference_audio'
| 'digital_human_frame'
| 'reference'
| 'element'
| 'video_feature'
| 'video_base'
| 'shot_prompt';
export interface VideoGenerationContent {
type: 'text' | 'image_url' | 'audio_url' | 'video_url' | 'element';
text?: string;
image_url?: {
url: string;
};
video_url?: {
url: string;
refer_type?: 'feature' | 'base';
keep_original_sound?: 'yes' | 'no';
};
audio_url?: {
url: string;
};
role?: VideoGenerationContentRole;
shot_index?: number;
duration?: number;
name?: string;
element?: {
system_element_id?: string;
inline_element?: {
name: string;
description?: string;
frontal_image_url: string;
refer_images: Array<{ url: string; slot_key?: string }>;
tags?: string[];
};
};
}
export interface VideoGenerationParams {
content: VideoGenerationContent[];
model: string;
aspect_ratio?: string;
resolution?: string;
duration?: number;
audio_list?: Array<{
url?: string;
audio_url?: string;
name?: string;
}>;
audio?: boolean;
framespersecond?: number;
watermark?: boolean;
seed?: number;
camerafixed?: boolean;
camera_control?: string;
camera_control_strength?: number;
prompt_extend?: boolean;
size?: string;
task_id?: string;
conversation_id?: string;
histories?: string;
callback_url?: string;
prompt_optimizer?: boolean;
fast_pretreatment?: boolean;
mode?: 'std' | 'pro';
negative_prompt?: string;
cfg_scale?: number;
}
export async function createVideoGenerationTask(
token: string,
input: {
audio?: boolean;
model: string;
prompt: string;
aspect_ratio?: string;
count?: number;
duration?: number;
duration_seconds?: number;
height?: number;
n?: number;
output_audio?: boolean;
resolution?: string;
runMode?: string;
simulation?: boolean;
size?: string;
width?: number;
},
input: VideoGenerationParams,
): Promise<{ task: GatewayTask; next: Record<string, string> }> {
return request<{ task: GatewayTask; next: Record<string, string> }>('/api/v1/videos/generations', {
body: input,
@ -660,11 +749,46 @@ export async function createVideoGenerationTask(
});
}
export interface GatewayFileUploadResponse extends Record<string, unknown> {
fileUrl?: string;
file_url?: string;
url?: string;
}
export async function uploadFileToStorage(
token: string,
file: File,
source = 'ai-gateway-playground',
): Promise<GatewayFileUploadResponse> {
const form = new FormData();
form.append('file', file);
form.append('source', source);
const response = await fetch(`${API_BASE}/v1/files/upload`, {
body: form,
headers: {
Authorization: `Bearer ${token}`,
},
method: 'POST',
});
const body = await response.text();
if (!response.ok) {
throw new GatewayApiError(parseErrorDetails(body, response.status, `Request failed: ${response.status}`));
}
if (!body) return {};
try {
const parsed = JSON.parse(body) as unknown;
return recordFromUnknown(parsed) ? (parsed as GatewayFileUploadResponse) : {};
} catch {
return { url: body };
}
}
export async function estimatePricing(
token: string,
input: Record<string, unknown>,
): Promise<{ items: unknown[]; resolver: string }> {
return request<{ items: unknown[]; resolver: string }>('/api/v1/pricing/estimate', {
): Promise<GatewayPricingEstimate> {
return request<GatewayPricingEstimate>('/api/v1/pricing/estimate', {
body: input,
method: 'POST',
token,
@ -758,6 +882,55 @@ export async function getNetworkProxyConfig(token: string): Promise<GatewayNetwo
return request<GatewayNetworkProxyConfig>('/api/admin/config/network-proxy', { token });
}
export async function listFileStorageChannels(token: string): Promise<ListResponse<FileStorageChannel>> {
return request<ListResponse<FileStorageChannel>>('/api/admin/system/file-storage/channels', { token });
}
export async function getFileStorageSettings(token: string): Promise<FileStorageSettings> {
return request<FileStorageSettings>('/api/admin/system/file-storage/settings', { token });
}
export async function updateFileStorageSettings(
token: string,
input: FileStorageSettingsUpdateRequest,
): Promise<FileStorageSettings> {
return request<FileStorageSettings>('/api/admin/system/file-storage/settings', {
body: input,
method: 'PATCH',
token,
});
}
export async function createFileStorageChannel(
token: string,
input: FileStorageChannelUpsertRequest,
): Promise<FileStorageChannel> {
return request<FileStorageChannel>('/api/admin/system/file-storage/channels', {
body: input,
method: 'POST',
token,
});
}
export async function updateFileStorageChannel(
token: string,
channelId: string,
input: FileStorageChannelUpsertRequest,
): Promise<FileStorageChannel> {
return request<FileStorageChannel>(`/api/admin/system/file-storage/channels/${channelId}`, {
body: input,
method: 'PATCH',
token,
});
}
export async function deleteFileStorageChannel(token: string, channelId: string): Promise<void> {
await request<void>(`/api/admin/system/file-storage/channels/${channelId}`, {
method: 'DELETE',
token,
});
}
async function request<T>(
path: string,
options: { token?: string; auth?: boolean; method?: string; body?: unknown; headers?: Record<string, string> } = {},

View File

@ -1,6 +1,8 @@
import type {
BaseModelCatalogItem,
CatalogProvider,
FileStorageChannel,
FileStorageSettings,
GatewayAccessRule,
GatewayApiKey,
GatewayAuditLog,
@ -27,6 +29,8 @@ export interface ConsoleData {
auditLogs: GatewayAuditLog[];
apiKeys: GatewayApiKey[];
baseModels: BaseModelCatalogItem[];
fileStorageChannels: FileStorageChannel[];
fileStorageSettings: FileStorageSettings | null;
modelCatalog: ModelCatalogResponse;
models: PlatformModel[];
networkProxyConfig: GatewayNetworkProxyConfig | null;

View File

@ -1,8 +1,10 @@
import type { ReactNode } from 'react';
import { Boxes, Building2, Gauge, History, KeyRound, Route, ServerCog, ShieldCheck, UsersRound, Workflow } from 'lucide-react';
import { Boxes, Building2, Gauge, History, KeyRound, Route, ServerCog, Settings, ShieldCheck, UsersRound, Workflow } from 'lucide-react';
import type {
BaseModelUpsertRequest,
CatalogProviderUpsertRequest,
FileStorageChannelUpsertRequest,
FileStorageSettingsUpdateRequest,
GatewayAccessRuleBatchRequest,
GatewayAccessRuleUpsertRequest,
GatewayTenantUpsertRequest,
@ -29,6 +31,7 @@ import { PricingRulesPanel } from './admin/PricingRulesPanel';
import { ProviderManagementPanel } from './admin/ProviderManagementPanel';
import { RealtimeLoadPanel } from './admin/RealtimeLoadPanel';
import { RuntimePoliciesPanel } from './admin/RuntimePoliciesPanel';
import { SystemSettingsPanel } from './admin/SystemSettingsPanel';
const tabs = [
{ value: 'overview', label: '总览', icon: <Workflow size={15} /> },
@ -42,6 +45,7 @@ const tabs = [
{ value: 'users', label: '用户', icon: <UsersRound size={15} /> },
{ value: 'userGroups', label: '用户组', icon: <UsersRound size={15} /> },
{ value: 'accessRules', label: '模型权限', icon: <KeyRound size={15} /> },
{ value: 'systemSettings', label: '系统设置', icon: <Settings size={15} /> },
{ value: 'auditLogs', label: '审计日志', icon: <History size={15} /> },
] satisfies Array<{ value: AdminSection; label: string; icon: ReactNode }>;
@ -57,6 +61,7 @@ export function AdminPage(props: {
onDeletePricingRuleSet: (ruleSetId: string) => Promise<void>;
onDeleteRuntimePolicySet: (policySetId: string) => Promise<void>;
onDeleteAccessRule: (ruleId: string) => Promise<void>;
onDeleteFileStorageChannel: (channelId: string) => Promise<void>;
onDeleteTenant: (tenantId: string) => Promise<void>;
onDeleteUser: (userId: string) => Promise<void>;
onDeleteUserGroup: (groupId: string) => Promise<void>;
@ -72,6 +77,8 @@ export function AdminPage(props: {
onSaveRunnerPolicy: (input: GatewayRunnerPolicyUpsertRequest) => Promise<void>;
onSaveRuntimePolicySet: (input: RuntimePolicySetUpsertRequest, policySetId?: string) => Promise<void>;
onSaveAccessRule: (input: GatewayAccessRuleUpsertRequest, ruleId?: string) => Promise<void>;
onSaveFileStorageChannel: (input: FileStorageChannelUpsertRequest, channelId?: string) => Promise<void>;
onSaveFileStorageSettings: (input: FileStorageSettingsUpdateRequest) => Promise<void>;
onSaveTenant: (input: GatewayTenantUpsertRequest, tenantId?: string) => Promise<void>;
onSaveUser: (input: GatewayUserUpsertRequest, userId?: string) => Promise<void>;
onSetUserWalletBalance: (userId: string, input: WalletBalanceAdjustmentRequest) => Promise<void>;
@ -172,6 +179,17 @@ export function AdminPage(props: {
{props.section === 'users' && <UsersPanel {...identityPanelProps(props)} />}
{props.section === 'userGroups' && <UserGroupsPanel {...identityPanelProps(props)} />}
{props.section === 'auditLogs' && <AuditLogsPanel auditLogs={props.data.auditLogs} message={props.operationMessage} />}
{props.section === 'systemSettings' && (
<SystemSettingsPanel
channels={props.data.fileStorageChannels}
settings={props.data.fileStorageSettings}
message={props.operationMessage}
state={props.state}
onDeleteFileStorageChannel={props.onDeleteFileStorageChannel}
onSaveFileStorageChannel={props.onSaveFileStorageChannel}
onSaveFileStorageSettings={props.onSaveFileStorageSettings}
/>
)}
</div>
</div>
</div>

View File

@ -34,6 +34,7 @@ export function ApiDocsPage(props: {
onTaskFormChange: (value: TaskForm) => void;
}) {
const current = docs.find((item) => item.key === props.activeDocSection) ?? docs[0];
const isFileDoc = current.key === 'files';
const bodyExample = useMemo(() => requestBodyExample(props.taskForm), [props.taskForm]);
function handleSubmit(event: FormEvent<HTMLFormElement>) {
@ -87,7 +88,7 @@ export function ApiDocsPage(props: {
<h2>Header </h2>
<Button type="button" variant="secondary" size="sm"></Button>
</header>
<ParamRow name="Content-Type" type="string" required value="application/json" />
<ParamRow name="Content-Type" type="string" required value={isFileDoc ? 'multipart/form-data' : 'application/json'} />
<ParamRow name="Accept" type="string" required value="application/json" />
<ParamRow name="Authorization" type="string" value="Bearer {{YOUR_API_KEY}}" />
</section>
@ -97,10 +98,19 @@ export function ApiDocsPage(props: {
<h2>Body </h2>
<Badge variant="outline">application/json</Badge>
</header>
<ParamRow name="model" type="string" required value="模型 ID 或别名" />
<ParamRow name="messages / prompt" type="array|string" required value="对话消息或图片提示词" />
<ParamRow name="simulation" type="boolean" value="测试模式开关" />
<ParamRow name="stream" type="boolean" value="对话进度流式返回" />
{isFileDoc ? (
<>
<ParamRow name="file" type="file" required value="multipart 文件字段" />
<ParamRow name="source" type="string" value="上传来源标记" />
</>
) : (
<>
<ParamRow name="model" type="string" required value="模型 ID 或别名" />
<ParamRow name="messages / prompt" type="array|string" required value="对话消息或图片提示词" />
<ParamRow name="simulation" type="boolean" value="测试模式开关" />
<ParamRow name="stream" type="boolean" value="对话进度流式返回" />
</>
)}
</section>
</main>

File diff suppressed because it is too large Load Diff

View File

@ -78,8 +78,10 @@ type UserGroupForm = {
description: string;
source: string;
priority: string;
rechargeDiscountPolicyJson: string;
billingDiscountPolicyJson: string;
rechargeDiscountFactor: string;
rechargeDiscountPolicy: Record<string, unknown>;
billingDiscountFactor: string;
billingDiscountPolicy: Record<string, unknown>;
rateLimitPolicyJson: string;
quotaPolicyJson: string;
metadataJson: string;
@ -516,8 +518,8 @@ export function UserGroupsPanel(props: IdentityPanelProps) {
<Label><Select size="sm" value={form.status} onChange={(event) => setForm({ ...form, status: event.target.value })}>{userGroupStatuses.map(option)}</Select></Label>
<Label><Input size="sm" value={form.priority} inputMode="numeric" onChange={(event) => setForm({ ...form, priority: event.target.value })} /></Label>
<Label className="spanTwo"><Input size="sm" value={form.description} onChange={(event) => setForm({ ...form, description: event.target.value })} /></Label>
<JsonField label="充值折扣策略 JSON" value={form.rechargeDiscountPolicyJson} onChange={(value) => setForm({ ...form, rechargeDiscountPolicyJson: value })} />
<JsonField label="计费折扣策略 JSON" value={form.billingDiscountPolicyJson} onChange={(value) => setForm({ ...form, billingDiscountPolicyJson: value })} />
<Label><Input size="sm" value={form.rechargeDiscountFactor} inputMode="decimal" placeholder="1 = 不打折0.95 = 95 折" onChange={(event) => setForm({ ...form, rechargeDiscountFactor: event.target.value })} /></Label>
<Label><Input size="sm" value={form.billingDiscountFactor} inputMode="decimal" placeholder="1 = 不打折0.95 = 95 折" onChange={(event) => setForm({ ...form, billingDiscountFactor: event.target.value })} /></Label>
<JsonField label="限流策略 JSON" value={form.rateLimitPolicyJson} onChange={(value) => setForm({ ...form, rateLimitPolicyJson: value })} />
<JsonField label="额度策略 JSON" value={form.quotaPolicyJson} onChange={(value) => setForm({ ...form, quotaPolicyJson: value })} />
<JsonField label="元数据 JSON" value={form.metadataJson} onChange={(value) => setForm({ ...form, metadataJson: value })} />
@ -769,8 +771,10 @@ function defaultUserGroupForm(): UserGroupForm {
description: '',
source: 'gateway',
priority: '100',
rechargeDiscountPolicyJson: '{}',
billingDiscountPolicyJson: '{}',
rechargeDiscountFactor: '1',
rechargeDiscountPolicy: {},
billingDiscountFactor: '1',
billingDiscountPolicy: {},
rateLimitPolicyJson: '{"rules":[]}',
quotaPolicyJson: '{}',
metadataJson: '{}',
@ -785,8 +789,10 @@ function userGroupToForm(group: UserGroup): UserGroupForm {
description: group.description ?? '',
source: group.source,
priority: String(group.priority),
rechargeDiscountPolicyJson: stringifyJson(group.rechargeDiscountPolicy),
billingDiscountPolicyJson: stringifyJson(group.billingDiscountPolicy),
rechargeDiscountFactor: discountFactorText(group.rechargeDiscountPolicy),
rechargeDiscountPolicy: group.rechargeDiscountPolicy ?? {},
billingDiscountFactor: discountFactorText(group.billingDiscountPolicy),
billingDiscountPolicy: group.billingDiscountPolicy ?? {},
rateLimitPolicyJson: stringifyJson(group.rateLimitPolicy),
quotaPolicyJson: stringifyJson(group.quotaPolicy),
metadataJson: stringifyJson(group.metadata),
@ -801,8 +807,8 @@ function formToUserGroupPayload(form: UserGroupForm): UserGroupUpsertRequest {
description: form.description.trim() || undefined,
source: form.source,
priority: Number(form.priority) || 100,
rechargeDiscountPolicy: parseJsonObject(form.rechargeDiscountPolicyJson, '充值折扣策略 JSON'),
billingDiscountPolicy: parseJsonObject(form.billingDiscountPolicyJson, '计费折扣策略 JSON'),
rechargeDiscountPolicy: discountPolicyPayload(form.rechargeDiscountPolicy, form.rechargeDiscountFactor, '充值折扣系数'),
billingDiscountPolicy: discountPolicyPayload(form.billingDiscountPolicy, form.billingDiscountFactor, '计费折扣系数'),
rateLimitPolicy: parseJsonObject(form.rateLimitPolicyJson, '限流策略 JSON'),
quotaPolicy: parseJsonObject(form.quotaPolicyJson, '额度策略 JSON'),
metadata: parseJsonObject(form.metadataJson, '元数据 JSON'),
@ -854,14 +860,57 @@ function newIdempotencyKey() {
}
function discountSummary(group: UserGroup) {
const billing = group.billingDiscountPolicy?.discountFactor ?? group.billingDiscountPolicy?.factor;
const recharge = group.rechargeDiscountPolicy?.discountFactor ?? group.rechargeDiscountPolicy?.factor;
const billing = discountFactorFromPolicy(group.billingDiscountPolicy);
const recharge = discountFactorFromPolicy(group.rechargeDiscountPolicy);
const parts = [];
if (billing) parts.push(`计费 ${billing}`);
if (recharge) parts.push(`充值 ${recharge}`);
if (billing) parts.push(`计费 ${trimNumber(billing)}`);
if (recharge) parts.push(`充值 ${trimNumber(recharge)}`);
return parts.join(' / ') || '未设置';
}
function discountFactorText(policy?: Record<string, unknown>) {
const value = discountFactorFromPolicy(policy);
return value ? trimNumber(value) : '1';
}
function discountFactorFromPolicy(policy?: Record<string, unknown>) {
return numberFromUnknown(policy?.discountFactor) ?? numberFromUnknown(policy?.factor);
}
function discountPolicyPayload(basePolicy: Record<string, unknown>, discountText: string, label: string) {
const policy = { ...basePolicy };
delete policy.discountFactor;
delete policy.factor;
const discount = optionalPositiveNumber(discountText, label);
if (discount && discount !== 1) {
policy.discountFactor = discount;
}
return Object.keys(policy).length ? policy : undefined;
}
function optionalPositiveNumber(value: string, label: string) {
const text = value.trim();
if (!text) return undefined;
const parsed = Number(text);
if (!Number.isFinite(parsed) || parsed <= 0) {
throw new Error(`${label} 必须是大于 0 的数字`);
}
return parsed;
}
function numberFromUnknown(value: unknown) {
if (typeof value === 'number' && Number.isFinite(value) && value > 0) return value;
if (typeof value === 'string' && value.trim()) {
const parsed = Number(value);
if (Number.isFinite(parsed) && parsed > 0) return parsed;
}
return undefined;
}
function trimNumber(value: number) {
return value.toFixed(6).replace(/\.?0+$/, '');
}
function policyKeys(value?: Record<string, unknown>) {
if (!value) return [];
return Object.keys(value).slice(0, 3);

View File

@ -0,0 +1,444 @@
import { useEffect, useState, type FormEvent } from 'react';
import { Database, Pencil, Plus, RotateCcw, Save, ServerCog, Trash2 } from 'lucide-react';
import type { FileStorageChannel, FileStorageChannelUpsertRequest, FileStorageSettings, FileStorageSettingsUpdateRequest } from '@easyai-ai-gateway/contracts';
import { Badge, Button, Card, CardContent, CardHeader, CardTitle, ConfirmDialog, FormDialog, Input, Label, Select, Tabs, Textarea } from '../../components/ui';
import type { LoadState } from '../../types';
type SystemSettingsTab = 'fileStorage';
type FileStorageChannelForm = {
apiKey: string;
apiKeyPreview: string;
channelKey: string;
configJson: string;
name: string;
priority: string;
provider: string;
retryPolicyJson: string;
scenes: string[];
status: string;
uploadUrl: string;
};
const defaultUploadUrl = 'http://127.0.0.1:3001/v1/files/upload';
const defaultRetryPolicy = {
enabled: true,
maxRetries: 3,
backoffSeconds: [60, 120, 180],
strategy: 'exponential',
};
const providerOptions = [
{ value: 'server_main_openapi', label: 'server-main OpenAPI' },
{ value: 'aliyun_oss', label: '阿里云 OSS' },
{ value: 'tencent_cos', label: '腾讯云 COS' },
];
const defaultScenes = ['upload', 'image_result'];
const sceneOptions = [
{ value: 'upload', label: '上传', description: 'OpenAPI / 管理端主动上传文件' },
{ value: 'image_result', label: '返图', description: '模型返回 base64 / buffer 图片或视频后的转存' },
];
const resultUploadPolicyOptions = [
{ value: 'default', label: '默认:仅非链接资源转存', description: 'URL 结果直接保存base64 / buffer 等结果转存后保存 URL' },
{ value: 'upload_all', label: '全部转存', description: 'URL、base64、buffer 等返图结果都会转存到当前文件渠道' },
{ value: 'upload_none', label: '全部不转存', description: '链接结果直接保存base64 / buffer 结果写入网关本地静态托管后保存 URL' },
];
export function SystemSettingsPanel(props: {
channels: FileStorageChannel[];
message: string;
settings: FileStorageSettings | null;
state: LoadState;
onDeleteFileStorageChannel: (channelId: string) => Promise<void>;
onSaveFileStorageChannel: (input: FileStorageChannelUpsertRequest, channelId?: string) => Promise<void>;
onSaveFileStorageSettings: (input: FileStorageSettingsUpdateRequest) => Promise<void>;
}) {
const [activeTab, setActiveTab] = useState<SystemSettingsTab>('fileStorage');
const [dialogOpen, setDialogOpen] = useState(false);
const [editingChannel, setEditingChannel] = useState<FileStorageChannel | null>(null);
const [pendingDeleteChannel, setPendingDeleteChannel] = useState<FileStorageChannel | null>(null);
const [form, setForm] = useState<FileStorageChannelForm>(() => defaultChannelForm());
const [settingsPolicy, setSettingsPolicy] = useState(() => normalizeResultUploadPolicy(props.settings?.resultUploadPolicy));
const [localError, setLocalError] = useState('');
useEffect(() => {
setSettingsPolicy(normalizeResultUploadPolicy(props.settings?.resultUploadPolicy));
}, [props.settings?.resultUploadPolicy]);
function openCreateDialog() {
setEditingChannel(null);
setForm(defaultChannelForm(`server-main-${Date.now().toString(36)}`));
setLocalError('');
setDialogOpen(true);
}
function editChannel(channel: FileStorageChannel) {
setEditingChannel(channel);
setForm(channelToForm(channel));
setLocalError('');
setDialogOpen(true);
}
function closeDialog() {
setEditingChannel(null);
setForm(defaultChannelForm());
setLocalError('');
setDialogOpen(false);
}
async function submit(event: FormEvent<HTMLFormElement>) {
event.preventDefault();
setLocalError('');
if (form.scenes.length === 0) {
setLocalError('请至少选择一个适用场景。');
return;
}
try {
await props.onSaveFileStorageChannel(formToPayload(form), editingChannel?.id);
closeDialog();
} catch (err) {
setLocalError(err instanceof Error ? err.message : '文件存储渠道保存失败');
}
}
async function deleteChannel(channel: FileStorageChannel) {
try {
await props.onDeleteFileStorageChannel(channel.id);
setPendingDeleteChannel(null);
if (editingChannel?.id === channel.id) closeDialog();
} catch (err) {
setLocalError(err instanceof Error ? err.message : '文件存储渠道删除失败');
}
}
async function saveSettings() {
setLocalError('');
try {
await props.onSaveFileStorageSettings({ resultUploadPolicy: normalizeResultUploadPolicy(settingsPolicy) });
} catch (err) {
setLocalError(err instanceof Error ? err.message : '文件存储全局策略保存失败');
}
}
return (
<div className="pageStack">
<Card>
<CardHeader>
<div>
<CardTitle></CardTitle>
<p className="mutedText">使 60/120/180 退</p>
</div>
<Badge variant="secondary">{props.channels.length} </Badge>
</CardHeader>
<CardContent>
{(props.message || localError) && <p className="formMessage">{localError || props.message}</p>}
<Tabs
value={activeTab}
tabs={[{ value: 'fileStorage', label: '文件存储', icon: <Database size={15} /> }]}
onValueChange={setActiveTab}
/>
</CardContent>
</Card>
{activeTab === 'fileStorage' && (
<section className="fileStoragePanel">
<div className="fileStorageSettingsCard">
<div>
<strong></strong>
<span></span>
</div>
<Label>
<Select value={settingsPolicy} onChange={(event) => setSettingsPolicy(event.target.value)}>
{resultUploadPolicyOptions.map((item) => <option value={item.value} key={item.value}>{item.label}</option>)}
</Select>
<small>{resultUploadPolicyDescription(settingsPolicy)}</small>
</Label>
<Button type="button" onClick={saveSettings} disabled={props.state === 'loading'}>
<Save size={15} />
</Button>
</div>
<div className="fileStorageToolbar">
<div>
<strong></strong>
<span>server-main OpenAPI API Key</span>
</div>
<Button type="button" onClick={openCreateDialog}>
<Plus size={15} />
</Button>
</div>
<div className="fileStorageGrid">
{props.channels.map((channel) => (
<article className="fileStorageCard" key={channel.id}>
<header>
<div className="iconBox"><ServerCog size={18} /></div>
<div>
<strong>{channel.name}</strong>
<span>{channel.channelKey}</span>
</div>
<Badge variant={channel.status === 'enabled' ? 'success' : 'secondary'}>{channel.status}</Badge>
</header>
<div className="fileStorageMeta">
<span>: {providerLabel(channel.provider)}</span>
<span>: {sceneSummary(channel.scenes)}</span>
<span>: {channel.priority}</span>
<span>: {retryPolicySummary(channel.retryPolicy)}</span>
{channel.uploadUrl && <span>: {channel.uploadUrl}</span>}
{apiKeyPreview(channel) && <span>API Key: {apiKeyPreview(channel)}</span>}
{channel.lastError && <span>: {channel.lastError}</span>}
</div>
<footer>
<Button type="button" variant="outline" size="sm" onClick={() => editChannel(channel)}>
<Pencil size={14} />
</Button>
<Button type="button" variant="destructive" size="sm" onClick={() => setPendingDeleteChannel(channel)}>
<Trash2 size={14} />
</Button>
</footer>
</article>
))}
{!props.channels.length && (
<Card>
<CardContent className="emptyState">
<strong></strong>
</CardContent>
</Card>
)}
</div>
</section>
)}
<FormDialog
ariaLabel={editingChannel ? '编辑文件存储渠道' : '新增文件存储渠道'}
bodyClassName="fileStorageDialogBody"
eyebrow={editingChannel ? 'Edit Storage Channel' : 'New Storage Channel'}
footer={(
<>
<Button type="submit" disabled={props.state === 'loading'}>
{editingChannel ? <Save size={15} /> : <Plus size={15} />}
{editingChannel ? '保存渠道' : '新增渠道'}
</Button>
<Button type="button" variant="outline" onClick={closeDialog}>
<RotateCcw size={15} />
</Button>
</>
)}
open={dialogOpen}
title={editingChannel ? '编辑文件存储渠道' : '新增文件存储渠道'}
onClose={closeDialog}
onSubmit={submit}
>
<Label>
<Input value={form.channelKey} onChange={(event) => setForm({ ...form, channelKey: event.target.value })} placeholder="server-main-openapi" />
</Label>
<Label>
<Input value={form.name} onChange={(event) => setForm({ ...form, name: event.target.value })} placeholder="server-main OpenAPI" />
</Label>
<Label>
<Select value={form.provider} onChange={(event) => setForm({ ...form, provider: event.target.value })}>
{providerOptions.map((item) => <option value={item.value} key={item.value}>{item.label}</option>)}
</Select>
</Label>
<Label>
<Select value={form.status} onChange={(event) => setForm({ ...form, status: event.target.value })}>
<option value="enabled">enabled</option>
<option value="disabled">disabled</option>
</Select>
</Label>
<Label className="spanTwo">
<div className="fileStorageSceneGrid">
{sceneOptions.map((scene) => (
<FileStorageSceneToggle
checked={form.scenes.includes(scene.value)}
description={scene.description}
key={scene.value}
label={scene.label}
onChange={(checked) => setForm({ ...form, scenes: nextScenes(form.scenes, scene.value, checked) })}
/>
))}
</div>
</Label>
<Label className="spanTwo">
<Input value={form.uploadUrl} onChange={(event) => setForm({ ...form, uploadUrl: event.target.value })} placeholder={defaultUploadUrl} />
</Label>
<Label className="platformCredentialField">
API Key
<Input value={form.apiKey} onChange={(event) => setForm({ ...form, apiKey: event.target.value })} placeholder={credentialInputPlaceholder(form.apiKeyPreview)} />
<small></small>
</Label>
<Label>
<Input type="number" min={1} value={form.priority} onChange={(event) => setForm({ ...form, priority: event.target.value })} />
</Label>
<Label className="spanTwo">
JSON
<Textarea value={form.retryPolicyJson} onChange={(event) => setForm({ ...form, retryPolicyJson: event.target.value })} />
</Label>
<Label className="spanTwo">
JSON
<Textarea value={form.configJson} onChange={(event) => setForm({ ...form, configJson: event.target.value })} />
</Label>
</FormDialog>
<ConfirmDialog
confirmLabel="删除渠道"
description="删除后该文件存储渠道不会再参与上传轮转。"
loading={props.state === 'loading'}
open={Boolean(pendingDeleteChannel)}
title={`确认删除文件存储渠道 ${pendingDeleteChannel?.name ?? ''}`}
onCancel={() => setPendingDeleteChannel(null)}
onConfirm={() => pendingDeleteChannel ? deleteChannel(pendingDeleteChannel) : undefined}
/>
</div>
);
}
function defaultChannelForm(channelKey = ''): FileStorageChannelForm {
return {
apiKey: '',
apiKeyPreview: '',
channelKey,
configJson: '{}',
name: 'server-main OpenAPI',
priority: '100',
provider: 'server_main_openapi',
retryPolicyJson: stringifyJson(defaultRetryPolicy),
scenes: defaultScenes,
status: 'disabled',
uploadUrl: defaultUploadUrl,
};
}
function channelToForm(channel: FileStorageChannel): FileStorageChannelForm {
const preview = apiKeyPreview(channel);
return {
apiKey: preview,
apiKeyPreview: preview,
channelKey: channel.channelKey,
configJson: stringifyJson(channel.config ?? {}),
name: channel.name,
priority: String(channel.priority || 100),
provider: channel.provider || 'server_main_openapi',
retryPolicyJson: stringifyJson(channel.retryPolicy ?? defaultRetryPolicy),
scenes: normalizeScenes(channel.scenes),
status: channel.status || 'disabled',
uploadUrl: channel.uploadUrl || defaultUploadUrl,
};
}
function formToPayload(form: FileStorageChannelForm): FileStorageChannelUpsertRequest {
return {
apiKey: apiKeyPayloadValue(form),
channelKey: form.channelKey.trim(),
config: parseJsonObject(form.configJson, '扩展配置 JSON'),
name: form.name.trim(),
priority: Number(form.priority) || 100,
provider: form.provider,
retryPolicy: parseJsonObject(form.retryPolicyJson, '重试策略 JSON'),
scenes: normalizeScenes(form.scenes),
status: form.status,
uploadUrl: form.uploadUrl.trim(),
};
}
function parseJsonObject(value: string, label: string) {
try {
const parsed = JSON.parse(value || '{}') as unknown;
if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) {
throw new Error(`${label} 必须是对象`);
}
return parsed as Record<string, unknown>;
} catch (err) {
if (err instanceof Error && err.message.includes(label)) throw err;
throw new Error(`${label} 格式不正确`);
}
}
function stringifyJson(value: unknown) {
return JSON.stringify(value ?? {}, null, 2);
}
function providerLabel(provider: string) {
return providerOptions.find((item) => item.value === provider)?.label ?? provider;
}
function sceneSummary(scenes: string[] | undefined) {
return normalizeScenes(scenes).map((scene) => sceneOptions.find((item) => item.value === scene)?.label ?? scene).join(' / ');
}
function normalizeResultUploadPolicy(value: string | undefined) {
const normalized = (value || 'default').trim();
return resultUploadPolicyOptions.some((item) => item.value === normalized) ? normalized : 'default';
}
function resultUploadPolicyDescription(value: string | undefined) {
const normalized = normalizeResultUploadPolicy(value);
return resultUploadPolicyOptions.find((item) => item.value === normalized)?.description ?? '';
}
function normalizeScenes(scenes: string[] | undefined) {
const next = Array.from(new Set((scenes ?? []).map((scene) => scene.trim()).filter(Boolean)));
return next.length ? next : [...defaultScenes];
}
function nextScenes(current: string[], scene: string, checked: boolean) {
if (checked) return normalizeScenes([...current, scene]);
return current.filter((item) => item !== scene);
}
function retryPolicySummary(policy?: Record<string, unknown>) {
const maxRetries = numberFromUnknown(policy?.maxRetries) || 3;
const backoff = Array.isArray(policy?.backoffSeconds) ? policy?.backoffSeconds.join('/') : '60/120/180';
return `${maxRetries} 次 · ${backoff}s`;
}
function apiKeyPreview(channel: FileStorageChannel) {
const value = channel.credentialsPreview?.apiKey;
return typeof value === 'string' ? value : '';
}
function apiKeyPayloadValue(form: FileStorageChannelForm) {
const value = form.apiKey.trim();
if (form.apiKeyPreview && value === form.apiKeyPreview) return undefined;
return value || (form.apiKeyPreview ? '' : undefined);
}
function credentialInputPlaceholder(preview: string) {
return preview ? '填写新凭证以覆盖当前值' : 'sk-...';
}
function FileStorageSceneToggle(props: { checked: boolean; description: string; label: string; onChange: (checked: boolean) => void }) {
return (
<label className="platformToggle">
<input type="checkbox" checked={props.checked} onChange={(event) => props.onChange(event.target.checked)} />
<span>
<strong>{props.label}</strong>
<small>{props.description}</small>
</span>
</label>
);
}
function numberFromUnknown(value: unknown) {
if (typeof value === 'number' && Number.isFinite(value)) return value;
if (typeof value === 'string' && value.trim()) {
const parsed = Number(value);
if (Number.isFinite(parsed)) return parsed;
}
return 0;
}

View File

@ -0,0 +1,685 @@
import { useEffect, useMemo, useRef, useState } from 'react';
import {
AssistantRuntimeProvider,
ComposerPrimitive,
ErrorPrimitive,
MessagePrimitive,
ThreadPrimitive,
useLocalRuntime,
useMessage,
useMessagePartText,
useThread,
type ChatModelAdapter,
type ThreadMessage,
type ThreadMessageLike,
} from '@assistant-ui/react';
import { StreamdownTextPrimitive } from '@assistant-ui/react-streamdown';
import { cjk } from '@streamdown/cjk';
import { code } from '@streamdown/code';
import { math } from '@streamdown/math';
import { mermaid } from '@streamdown/mermaid';
import type { GatewayApiKey } from '@easyai-ai-gateway/contracts';
import { Send } from 'lucide-react';
import { Button, Select } from '../components/ui';
import { GatewayApiError, streamChatCompletionText } from '../api';
import type { PlaygroundMode } from '../types';
import {
chatUploadAccept as sharedChatUploadAccept,
mediaUploadSummaryMessage as sharedMediaUploadSummaryMessage,
openAIContentFromPromptAndUploads,
PlaygroundReferencePicker,
uploadPlaygroundFiles as sharedUploadPlaygroundFiles,
type OpenAIChatContentPart,
type PlaygroundUpload,
} from './playground-upload';
import {
ModeSwitch,
PlaygroundGreeting,
apiKeyNoticeText,
modeOptions,
modelOptionLabel,
placeholderByMode,
resolveSelectedApiKeyId,
type ModelOption,
} from './playground-shared';
const CHAT_MESSAGES_STORAGE_KEY = 'easyai:playground:chat-messages:v1';
const CHAT_MESSAGES_STORAGE_LIMIT = 100;
const streamdownPlugins = { cjk, code, math, mermaid };
type OpenAIChatRole = 'assistant' | 'user';
interface StoredOpenAIChatMessage {
content: OpenAIChatContentPart[] | string;
createdAt: string;
id: string;
role: OpenAIChatRole;
}
type StoredOpenAIChatMessagesById = Record<string, StoredOpenAIChatMessage>;
export function AssistantChatPlayground(props: {
apiKeySecretsById: Record<string, string>;
apiKeys: GatewayApiKey[];
modelOptions: ModelOption[];
selectedApiKeyId: string;
selectedModel: string;
token: string;
onApiKeyChange: (apiKeyId: string) => void;
onCreateApiKey: () => void;
onLogin: () => void;
onModeChange: (mode: PlaygroundMode) => void;
onModelChange: (value: string) => void;
}) {
const activeApiKeyId = resolveSelectedApiKeyId(props.apiKeys, props.apiKeySecretsById, props.selectedApiKeyId);
const activeApiKeySecret = activeApiKeyId ? props.apiKeySecretsById[activeApiKeyId] ?? '' : '';
const canRun = Boolean(props.token && props.selectedModel && activeApiKeySecret);
const apiKeyNotice = apiKeyNoticeText(props.apiKeys, props.apiKeySecretsById);
const initialStoredMessages = useMemo(() => readStoredOpenAIChatMessages(), []);
const initialMessages = useMemo(() => initialStoredMessages.map(threadMessageLikeFromOpenAIMessage), [initialStoredMessages]);
const [storedMessagesById, setStoredMessagesById] = useState<StoredOpenAIChatMessagesById>(() => indexStoredOpenAIChatMessages(initialStoredMessages));
const [chatUploadMessage, setChatUploadMessage] = useState('');
const [chatUploads, setChatUploads] = useState<PlaygroundUpload[]>([]);
const [chatUploading, setChatUploading] = useState(false);
const chatUploadsRef = useRef(chatUploads);
const storedMessagesByIdRef = useRef(storedMessagesById);
useEffect(() => {
chatUploadsRef.current = chatUploads;
}, [chatUploads]);
useEffect(() => {
storedMessagesByIdRef.current = storedMessagesById;
}, [storedMessagesById]);
async function uploadChatFiles(files: File[]) {
if (!files.length) return;
if (!props.token) {
props.onLogin();
return;
}
if (!activeApiKeySecret) {
setChatUploadMessage('请选择可用于测试的 API Key 后再上传。');
return;
}
setChatUploading(true);
setChatUploadMessage('');
try {
const { items, warnings } = await sharedUploadPlaygroundFiles(activeApiKeySecret, files, {
allowFiles: true,
source: 'ai-gateway-playground-chat',
});
if (items.length) {
setChatUploads((current) => [...current, ...items]);
}
setChatUploadMessage(warnings[0] ?? (items.length ? `已上传 ${items.length} 个附件。` : ''));
} catch (err) {
setChatUploadMessage(err instanceof Error ? err.message : '文件上传失败');
} finally {
setChatUploading(false);
}
}
const adapter = useMemo<ChatModelAdapter>(() => ({
async *run({ abortSignal, messages }) {
if (!props.token) {
props.onLogin();
throw new GatewayApiError('请先登录后再测试模型。');
}
if (!activeApiKeySecret) {
throw new GatewayApiError('请选择可用于测试的 API Key如果列表为空请刷新或重新创建一个 Key。');
}
if (!props.selectedModel) {
throw new GatewayApiError('当前没有可用的大模型,请确认用户组权限或平台模型配置。');
}
const requestUploads = chatUploadsRef.current;
const request = buildGatewayChatMessages(messages, requestUploads, storedMessagesByIdRef.current);
if (request.lastUserMessage) {
setStoredMessagesById((current) => ({
...current,
[request.lastUserMessage!.id]: request.lastUserMessage!,
}));
}
if (requestUploads.length) {
chatUploadsRef.current = [];
setChatUploads([]);
setChatUploadMessage('');
}
let text = '';
for await (const delta of streamChatCompletionText(
activeApiKeySecret,
{
messages: request.messages,
model: props.selectedModel,
},
abortSignal,
)) {
text += delta;
yield {
content: [{ type: 'text', text }],
};
}
yield {
content: [{ type: 'text', text }],
status: { type: 'complete', reason: 'stop' },
};
},
}), [activeApiKeySecret, props]);
const runtime = useLocalRuntime(adapter, { initialMessages });
return (
<AssistantRuntimeProvider runtime={runtime}>
<AssistantChatPersistenceBridge storedMessagesById={storedMessagesById} />
<ThreadPrimitive.Root className="assistantThreadRoot">
<ThreadPrimitive.Empty>
<div className="assistantEmptyStage">
<AssistantEmptyState
canRun={canRun}
modelOptions={props.modelOptions}
selectedModel={props.selectedModel}
token={props.token}
activeApiKeySecret={activeApiKeySecret}
uploadAccept={sharedChatUploadAccept}
uploadMessage={chatUploadMessage}
uploads={chatUploads}
uploading={chatUploading}
onModeChange={props.onModeChange}
onModelChange={props.onModelChange}
onRemoveUpload={(id) => setChatUploads((current) => current.filter((item) => item.id !== id))}
onUploadFiles={(files) => void uploadChatFiles(files)}
/>
</div>
</ThreadPrimitive.Empty>
<ThreadPrimitive.If empty={false}>
<div className="assistantShell" data-has-notice={Boolean(apiKeyNotice)}>
{apiKeyNotice && (
<div className="assistantApiKeyNotice">
<span>{apiKeyNotice}</span>
<Button type="button" size="sm" variant="secondary" onClick={props.onCreateApiKey}>
API Key
</Button>
</div>
)}
<ThreadPrimitive.Viewport className="assistantThreadViewport">
<div className="assistantMessageList">
<ThreadPrimitive.Messages
components={{
Message: () => (
<AssistantMessage storedMessagesById={storedMessagesById} />
),
}}
/>
</div>
<ThreadPrimitive.ViewportFooter className="assistantComposerDock">
<AssistantChatComposer
canRun={canRun}
docked
modelOptions={props.modelOptions}
placeholder={assistantPlaceholder(props.token, props.selectedModel, activeApiKeySecret)}
selectedModel={props.selectedModel}
uploadAccept={sharedChatUploadAccept}
uploadMessage={chatUploadMessage}
uploads={chatUploads}
uploading={chatUploading}
onModeChange={props.onModeChange}
onModelChange={props.onModelChange}
onRemoveUpload={(id) => setChatUploads((current) => current.filter((item) => item.id !== id))}
onUploadFiles={(files) => void uploadChatFiles(files)}
/>
</ThreadPrimitive.ViewportFooter>
</ThreadPrimitive.Viewport>
</div>
</ThreadPrimitive.If>
</ThreadPrimitive.Root>
</AssistantRuntimeProvider>
);
}
function AssistantChatPersistenceBridge(props: { storedMessagesById: StoredOpenAIChatMessagesById }) {
const messages = useThread((state) => state.messages);
const skipInitialEmptyWriteRef = useRef(true);
useEffect(() => {
if (skipInitialEmptyWriteRef.current) {
skipInitialEmptyWriteRef.current = false;
if (!messages.length && hasStoredChatMessages()) return;
}
writeStoredOpenAIChatMessages(messages, props.storedMessagesById);
}, [messages, props.storedMessagesById]);
return null;
}
function AssistantEmptyState(props: {
activeApiKeySecret: string;
canRun: boolean;
modelOptions: ModelOption[];
selectedModel: string;
token: string;
uploadAccept: string;
uploadMessage: string;
uploads: PlaygroundUpload[];
uploading: boolean;
onModeChange: (mode: PlaygroundMode) => void;
onModelChange: (value: string) => void;
onRemoveUpload: (id: string) => void;
onUploadFiles: (files: File[]) => void;
}) {
const activeMode = modeOptions.find((item) => item.value === 'chat') ?? modeOptions[0];
const placeholder = props.canRun ? placeholderByMode.chat : assistantPlaceholder(props.token, props.selectedModel, props.activeApiKeySecret);
return (
<div className="assistantEmpty">
<ModeSwitch activeMode="chat" onModeChange={props.onModeChange} />
<PlaygroundGreeting activeMode={activeMode} />
<AssistantChatComposer
canRun={props.canRun}
modelOptions={props.modelOptions}
placeholder={placeholder}
selectedModel={props.selectedModel}
uploadAccept={props.uploadAccept}
uploadMessage={props.uploadMessage}
uploads={props.uploads}
uploading={props.uploading}
onModeChange={props.onModeChange}
onModelChange={props.onModelChange}
onRemoveUpload={props.onRemoveUpload}
onUploadFiles={props.onUploadFiles}
/>
</div>
);
}
function AssistantChatComposer(props: {
canRun: boolean;
docked?: boolean;
modelOptions: ModelOption[];
placeholder: string;
selectedModel: string;
uploadAccept?: string;
uploadMessage?: string;
uploads?: PlaygroundUpload[];
uploading?: boolean;
onModeChange: (mode: PlaygroundMode) => void;
onModelChange: (value: string) => void;
onRemoveUpload?: (id: string) => void;
onUploadFiles?: (files: File[]) => void;
}) {
const className = ['playgroundComposer', 'assistantChatComposer', props.docked ? 'assistantDockComposer' : 'assistantEmptyComposer'].join(' ');
const uploadMessage = props.uploadMessage || sharedMediaUploadSummaryMessage(props.uploads ?? [], 'chat', 'text_to_video');
return (
<ComposerPrimitive.Root className={className}>
<div className="composerBody composerBodyWithReferences">
<PlaygroundReferencePicker
accept={props.uploadAccept ?? sharedChatUploadAccept}
disabled={!props.canRun || !props.onUploadFiles}
mode="chat"
uploadLabel="上传附件"
uploads={props.uploads ?? []}
uploading={props.uploading}
onFiles={props.onUploadFiles}
onRemove={props.onRemoveUpload}
/>
<div className="composerInputStack">
<ComposerPrimitive.Input
className="assistantEmptyInput"
disabled={!props.canRun}
placeholder={props.placeholder}
/>
{uploadMessage && <div className="composerUploadMessage">{uploadMessage}</div>}
</div>
</div>
<div className="composerFooter">
<Select value="chat" onChange={(event) => props.onModeChange(event.target.value as PlaygroundMode)}>
{modeOptions.map((item) => <option value={item.value} key={item.value}>{item.label}</option>)}
</Select>
<Select
className="playgroundModelSelect"
value={props.selectedModel}
disabled={!props.modelOptions.length}
onChange={(event) => props.onModelChange(event.target.value)}
>
{props.modelOptions.length ? props.modelOptions.map((item) => (
<option value={item.value} key={item.value}>{modelOptionLabel(item)}</option>
)) : <option value=""></option>}
</Select>
<ComposerPrimitive.Send className="composerSendButton" disabled={!props.canRun} aria-label="发送消息">
<Send size={18} />
</ComposerPrimitive.Send>
</div>
</ComposerPrimitive.Root>
);
}
function AssistantMessage(props: { storedMessagesById: StoredOpenAIChatMessagesById }) {
const messageId = useMessage((state) => state.id);
const messageContent = useMessage((state) => state.content);
const hasError = useMessage((state) => state.status?.type === 'incomplete' && state.status.reason === 'error');
const storedMessage = messageId ? props.storedMessagesById[messageId] : undefined;
const imageParts = imagePartsFromOpenAIContent(storedMessage?.content);
const hasText = threadMessageContentText(messageContent).trim().length > 0;
return (
<MessagePrimitive.Root className="assistantMessage">
<MessagePrimitive.If user>
<div className="assistantUserMessage">
<ChatMessageImagePreviews parts={imageParts} />
{hasText && (
<div className="assistantBubble user">
<MessagePrimitive.Parts components={{ Text: PlainMessageText }} />
</div>
)}
</div>
</MessagePrimitive.If>
<MessagePrimitive.If assistant>
<div className={hasError ? 'assistantBubble assistant error' : 'assistantBubble assistant'}>
<MessagePrimitive.Parts components={{ Text: AssistantMarkdownText }} />
<MessagePrimitive.Error>
<strong></strong>
<ErrorPrimitive.Message className="assistantErrorMessage" />
</MessagePrimitive.Error>
{!hasError && (
<MessagePrimitive.If hasContent={false}>
<span className="assistantTyping">...</span>
</MessagePrimitive.If>
)}
</div>
</MessagePrimitive.If>
</MessagePrimitive.Root>
);
}
function ChatMessageImagePreviews(props: { parts: Array<{ name: string; url: string }> }) {
if (!props.parts.length) return null;
return (
<div className="assistantUserImageGrid">
{props.parts.map((item) => (
<a className="assistantUserImagePreview" href={item.url} key={item.url} rel="noreferrer" target="_blank" title={item.name}>
<img src={item.url} alt={item.name} loading="lazy" />
</a>
))}
</div>
);
}
function PlainMessageText() {
const { text } = useMessagePartText();
return <span className="assistantPlainText">{text}</span>;
}
function AssistantMarkdownText() {
return (
<StreamdownTextPrimitive
containerClassName="assistantMarkdown"
plugins={streamdownPlugins}
shikiTheme={['github-light', 'github-dark']}
/>
);
}
interface GatewayChatMessageForRequest extends Record<string, unknown> {
content: OpenAIChatContentPart[] | string;
role: OpenAIChatRole;
}
function buildGatewayChatMessages(
messages: readonly ThreadMessage[],
uploads: PlaygroundUpload[],
storedMessagesById: StoredOpenAIChatMessagesById,
) {
const sourceMessages = messages.filter((message) => message.role === 'user' || message.role === 'assistant');
let sourceLastUserIndex = -1;
sourceMessages.forEach((message, index) => {
if (message.role === 'user') sourceLastUserIndex = index;
});
const gatewayMessages: GatewayChatMessageForRequest[] = [];
let lastUserMessage: StoredOpenAIChatMessage | undefined;
sourceMessages.forEach((message, index) => {
const isUploadTarget = uploads.length > 0 && index === sourceLastUserIndex && message.role === 'user';
const text = threadMessageText(message);
const preserved = storedMessagesById[message.id];
const content = isUploadTarget
? openAIContentFromPromptAndUploads(text, uploads)
: preserved?.content ?? text;
if (!openAIContentHasPayload(content)) return;
gatewayMessages.push({
content,
role: message.role,
});
if (isUploadTarget) {
lastUserMessage = {
content,
createdAt: message.createdAt.toISOString(),
id: message.id,
role: 'user',
};
}
});
return { lastUserMessage, messages: gatewayMessages };
}
function readStoredOpenAIChatMessages(): StoredOpenAIChatMessage[] {
if (typeof window === 'undefined') return [];
try {
const raw = window.localStorage.getItem(CHAT_MESSAGES_STORAGE_KEY);
if (!raw) return [];
const parsed = JSON.parse(raw) as unknown;
const record = recordFromUnknown(parsed);
const source = Array.isArray(parsed) ? parsed : record?.messages;
if (!Array.isArray(source)) return [];
const legacyImagesByMessageId = recordFromUnknown(record?.imageUploadsByMessageId);
return source
.map((item) => storedOpenAIChatMessageFromStorage(item, legacyImagesByMessageId))
.filter((item): item is StoredOpenAIChatMessage => Boolean(item))
.slice(-CHAT_MESSAGES_STORAGE_LIMIT);
} catch {
return [];
}
}
function writeStoredOpenAIChatMessages(messages: readonly ThreadMessage[], preservedMessagesById: StoredOpenAIChatMessagesById) {
if (typeof window === 'undefined') return;
try {
const storedMessages = messages
.map((message) => storedOpenAIChatMessageFromThread(message, preservedMessagesById[message.id]))
.filter((item): item is StoredOpenAIChatMessage => Boolean(item))
.slice(-CHAT_MESSAGES_STORAGE_LIMIT);
if (!storedMessages.length) {
window.localStorage.removeItem(CHAT_MESSAGES_STORAGE_KEY);
return;
}
window.localStorage.setItem(CHAT_MESSAGES_STORAGE_KEY, JSON.stringify({
messages: storedMessages,
version: 2,
}));
} catch {
// Best effort only: local chat history should not block sending messages.
}
}
export function clearStoredChatMessages() {
if (typeof window === 'undefined') return;
try {
window.localStorage.removeItem(CHAT_MESSAGES_STORAGE_KEY);
} catch {
// Ignore storage errors.
}
}
function hasStoredChatMessages() {
if (typeof window === 'undefined') return false;
try {
return Boolean(window.localStorage.getItem(CHAT_MESSAGES_STORAGE_KEY));
} catch {
return false;
}
}
function storedOpenAIChatMessageFromThread(message: ThreadMessage, preserved?: StoredOpenAIChatMessage): StoredOpenAIChatMessage | undefined {
if (message.role !== 'assistant' && message.role !== 'user') return undefined;
const content = message.role === 'user' && preserved
? preserved.content
: threadMessageText(message);
if (!openAIContentHasPayload(content)) return undefined;
return {
content,
createdAt: message.createdAt.toISOString(),
id: message.id,
role: message.role,
};
}
function storedOpenAIChatMessageFromStorage(value: unknown, legacyImagesByMessageId?: Record<string, unknown>): StoredOpenAIChatMessage | undefined {
const record = recordFromUnknown(value);
if (!record) return undefined;
const role = record.role === 'assistant' || record.role === 'user' ? record.role : undefined;
const id = stringFromUnknown(record.id);
if (!role || !id) return undefined;
const createdAt = dateStringFromUnknown(record.createdAt) ?? new Date().toISOString();
let content = openAIContentFromUnknown(record.content);
const legacyImages = Array.isArray(legacyImagesByMessageId?.[id]) ? legacyImagesByMessageId[id] as unknown[] : [];
if (role === 'user' && legacyImages.length && (typeof content === 'string' || !content)) {
content = openAIContentWithLegacyImages(typeof content === 'string' ? content : '', legacyImages);
}
if (!content || !openAIContentHasPayload(content)) return undefined;
return {
content,
createdAt,
id,
role,
};
}
function threadMessageLikeFromOpenAIMessage(message: StoredOpenAIChatMessage): ThreadMessageLike {
return {
content: threadMessageLikeContentFromOpenAIContent(message.content),
createdAt: new Date(message.createdAt),
id: message.id,
role: message.role,
status: message.role === 'assistant' ? { type: 'complete', reason: 'stop' } : undefined,
};
}
function indexStoredOpenAIChatMessages(messages: StoredOpenAIChatMessage[]) {
return messages.reduce<StoredOpenAIChatMessagesById>((result, message) => {
result[message.id] = message;
return result;
}, {});
}
function threadMessageLikeContentFromOpenAIContent(content: OpenAIChatContentPart[] | string): ThreadMessageLike['content'] {
const text = openAIContentText(content);
return text || ' ';
}
function threadMessageText(message: ThreadMessage) {
return threadMessageContentText(message.content).trim();
}
function threadMessageContentText(content: ThreadMessage['content']) {
return content
.map((part) => part.type === 'text' ? part.text : '')
.join('')
.trim();
}
function openAIContentFromUnknown(value: unknown): OpenAIChatContentPart[] | string | undefined {
if (typeof value === 'string') return value;
if (!Array.isArray(value)) return undefined;
const parts = value
.map(openAIContentPartFromUnknown)
.filter((item): item is OpenAIChatContentPart => Boolean(item));
return parts.length ? parts : undefined;
}
function openAIContentPartFromUnknown(value: unknown): OpenAIChatContentPart | undefined {
const record = recordFromUnknown(value);
if (!record) return undefined;
if (record.type === 'text' && typeof record.text === 'string') {
return { type: 'text', text: record.text };
}
if (record.type === 'image_url') {
const imageUrl = recordFromUnknown(record.image_url);
const url = stringFromUnknown(imageUrl?.url);
return url ? { type: 'image_url', image_url: { url } } : undefined;
}
if (record.type === 'video_url') {
const videoUrl = recordFromUnknown(record.video_url);
const url = stringFromUnknown(videoUrl?.url);
return url ? { type: 'video_url', video_url: { url } } : undefined;
}
if (record.type === 'audio_url') {
const audioUrl = recordFromUnknown(record.audio_url);
const url = stringFromUnknown(audioUrl?.url);
return url ? { type: 'audio_url', audio_url: { url } } : undefined;
}
if (record.type === 'file_url') {
const fileUrl = recordFromUnknown(record.file_url);
const url = stringFromUnknown(fileUrl?.url);
const filename = stringFromUnknown(fileUrl?.filename) || '文件';
return url ? { type: 'file_url', file_url: { filename, url } } : undefined;
}
return undefined;
}
function openAIContentWithLegacyImages(text: string, images: unknown[]): OpenAIChatContentPart[] {
const content: OpenAIChatContentPart[] = [];
if (text.trim()) {
content.push({ type: 'text', text });
}
images.forEach((item) => {
const record = recordFromUnknown(item);
const url = stringFromUnknown(record?.url);
if (url) {
content.push({ type: 'image_url', image_url: { url } });
}
});
return content;
}
function openAIContentHasPayload(content: OpenAIChatContentPart[] | string) {
if (typeof content === 'string') return content.trim().length > 0;
return content.some((part) => {
if (part.type === 'text') return part.text.trim().length > 0;
return true;
});
}
function openAIContentText(content: OpenAIChatContentPart[] | string) {
if (typeof content === 'string') return content;
return content
.map((part) => part.type === 'text' ? part.text : '')
.join('')
.trim();
}
function imagePartsFromOpenAIContent(content: OpenAIChatContentPart[] | string | undefined) {
if (!Array.isArray(content)) return [];
return content.flatMap((part) => {
if (part.type !== 'image_url') return [];
return [{ name: '图片', url: part.image_url.url }];
});
}
function assistantPlaceholder(token: string, selectedModel: string, apiKeySecret: string) {
if (!token) return '请先登录后再测试模型';
if (!apiKeySecret) return '请选择可用于测试的 API Key';
if (!selectedModel) return '当前没有可用模型';
return '输入消息Enter 发送Shift + Enter 换行';
}
function recordFromUnknown(value: unknown): Record<string, unknown> | undefined {
if (!value || typeof value !== 'object' || Array.isArray(value)) return undefined;
return value as Record<string, unknown>;
}
function stringFromUnknown(value: unknown) {
return typeof value === 'string' ? value.trim() : '';
}
function dateStringFromUnknown(value: unknown) {
if (typeof value !== 'string') return undefined;
const timestamp = Date.parse(value);
return Number.isNaN(timestamp) ? undefined : new Date(timestamp).toISOString();
}

View File

@ -5,16 +5,19 @@ import Slider from 'antd/es/slider';
import {
Download,
Edit3,
FileText,
Image as ImageIcon,
Images,
Link2,
LoaderCircle,
Music2,
Sparkles,
Square,
} from 'lucide-react';
import { resolveApiAssetUrl } from '../api';
import { Button, Input, Popover, PopoverContent, PopoverTrigger } from '../components/ui';
import type { PlaygroundMode } from '../types';
import type { PlaygroundUpload, PlaygroundUploadKind, PlaygroundVideoCreateMode } from './playground-upload';
export type MediaOutputMode = 'single' | 'group';
export type MediaCountPreset = 1 | 2 | 3 | 4 | 'custom';
@ -46,6 +49,8 @@ export interface MediaGenerationRun {
settings: MediaGenerationSettings;
status: GatewayTask['status'] | 'submitting';
task?: GatewayTask;
uploads?: PlaygroundUpload[];
videoMode?: PlaygroundVideoCreateMode;
}
export function gatewayTaskErrorText(task: GatewayTask | undefined, fallback = '任务失败') {
@ -159,8 +164,6 @@ export function mediaRequestPayload(settings: MediaGenerationSettings, mode: Exc
aspect_ratio: settings.aspectRatio === 'auto' ? undefined : settings.aspectRatio,
audio: settings.outputAudio,
duration: settings.durationSeconds,
duration_seconds: settings.durationSeconds,
output_audio: settings.outputAudio,
resolution: settings.resolution,
};
}
@ -515,23 +518,44 @@ function MediaTaskCard(props: {
const isPending = props.run.status === 'submitting' || props.run.status === 'queued' || props.run.status === 'running';
const backdropItem = expectedCount === 1 && items[0]?.type === 'image' ? items[0] : undefined;
const errorText = mediaRunErrorText(props.run);
const references = mediaReferenceItems(props.run);
const promptParts = promptDisplayParts(props.run.prompt, references);
const taskMeta = mediaTaskMetaText(props.run);
const previewState = items.length > 0 ? 'filled' : isPending ? 'loading' : 'empty';
return (
<article className="mediaTaskItem" data-status={props.run.status}>
<header className="mediaTaskHeader">
<div>
<p>
<span>{props.run.prompt}</span>
<small>{props.run.mode === 'video' ? '视频' : '图片'} {props.run.modelLabel} {props.run.settings.aspectRatio} {props.run.settings.resolution}</small>
</p>
<time dateTime={props.run.createdAt}>{formatRunTime(props.run.createdAt)}</time>
<div className="mediaTaskHeaderMain">
{references.length > 0 && <MediaTaskReferenceStack references={references} />}
<div className="mediaTaskHeaderContent">
<div className="mediaTaskMetaLine">
<Sparkles size={16} />
<time dateTime={props.run.createdAt}>{formatRunDateTime(props.run.createdAt)}</time>
<span aria-hidden="true">|</span>
<strong>{props.run.modelLabel}</strong>
<span>{props.run.mode === 'video' ? '视频生成' : references.length ? '图像编辑' : '图像生成'}</span>
<span aria-hidden="true">|</span>
<span>{taskMeta}</span>
</div>
<div className="mediaTaskPromptLine">
<span className="mediaTaskPromptText">{promptParts}</span>
</div>
</div>
</div>
<span className="mediaTaskStatus" data-status={props.run.status}>{status}</span>
</header>
<div className="mediaPreviewStage" data-count={expectedCount}>
{errorText && (
<div className="mediaTaskError">
<strong></strong>
<span>{errorText}</span>
</div>
)}
<div className="mediaPreviewStage" data-count={expectedCount} data-preview-state={previewState}>
{backdropItem && <img aria-hidden="true" className="mediaPreviewBackdrop" src={backdropItem.src} alt="" />}
<div className="mediaGrid" data-count={expectedCount} style={style}>
<div className="mediaGrid" data-count={expectedCount} data-preview-state={previewState} style={style}>
{Array.from({ length: expectedCount }).map((_, index) => (
<MediaTile
expectedCount={expectedCount}
@ -545,12 +569,6 @@ function MediaTaskCard(props: {
</div>
</div>
{errorText && (
<div className="mediaTaskError">
<strong></strong>
<span>{errorText}</span>
</div>
)}
<footer className="mediaTaskActions">
{items[0] ? (
<Button asChild size="sm" variant="secondary">
@ -582,6 +600,57 @@ function MediaTaskCard(props: {
);
}
function MediaTaskReferenceStack(props: { references: PlaygroundUpload[] }) {
const visibleReferences = props.references.slice(0, 8);
const overflowCount = props.references.length - visibleReferences.length;
return (
<span
className="mediaTaskReferenceStack"
style={{ '--reference-count': Math.max(1, visibleReferences.length) } as CSSProperties}
title={props.references.map((item) => item.name).join('\n')}
>
{visibleReferences.map((item, index) => (
<span
className="mediaTaskReferenceCard"
data-kind={item.kind}
key={`${item.id}-${index}`}
style={taskReferenceCardStyle(index)}
>
<MediaTaskReferencePreview item={item} />
<small>{referenceKindLabel(item.kind)}</small>
</span>
))}
{overflowCount > 0 && (
<span className="mediaTaskReferenceOverflow">+{overflowCount}</span>
)}
</span>
);
}
function MediaTaskReferencePreview(props: { item: PlaygroundUpload }) {
if (props.item.kind === 'image') {
return <img src={props.item.url} alt="" draggable={false} />;
}
if (props.item.kind === 'video') {
return <video src={props.item.url} muted playsInline preload="metadata" />;
}
if (props.item.kind === 'audio') {
return <Music2 size={16} />;
}
return <FileText size={16} />;
}
function PromptResourceTag(props: { item: PlaygroundUpload; references: PlaygroundUpload[] }) {
return (
<span className="mediaPromptResourceTag" contentEditable={false}>
<span className="mediaPromptResourceThumb">
<MediaTaskReferencePreview item={props.item} />
</span>
<span>{mediaReferenceMentionLabel(props.item, props.references)}</span>
</span>
);
}
function MediaTile(props: {
expectedCount: number;
index: number;
@ -592,7 +661,7 @@ function MediaTile(props: {
const isLoading = props.status === 'submitting' || props.status === 'queued' || props.status === 'running';
const isFailed = props.status === 'failed' || props.status === 'cancelled';
return (
<div className="mediaTile" data-count={props.expectedCount} data-empty={!props.item && !isLoading} data-kind={props.mode}>
<div className="mediaTile" data-count={props.expectedCount} data-empty={!props.item && !isLoading} data-kind={props.mode} data-placeholder={!props.item}>
{props.item?.type === 'video' && (
<video controls muted playsInline poster={props.item.poster}>
<source src={props.item.src} />
@ -630,6 +699,193 @@ function mediaRunErrorText(run: MediaGenerationRun) {
return gatewayTaskErrorText(run.task, '') || run.error || '';
}
const mediaResourceTokenPattern = /<<<playground-resource:([^>]+)>>>/g;
const taskReferenceTiltValues = [-10, 8, -5, 9, -7, 5, -8, 6];
const taskReferenceYValues = [0, 3, -1, 2, -2, 4, 1, -3];
function promptDisplayParts(raw: string, references: PlaygroundUpload[]): ReactNode[] {
if (!raw.includes('<<<playground-resource:')) return [raw];
const parts: ReactNode[] = [];
const byId = new Map(references.map((item) => [item.id, item]));
const usedIds = new Set<string>();
let fallbackIndex = 0;
let lastIndex = 0;
let match: RegExpExecArray | null;
const re = new RegExp(mediaResourceTokenPattern);
while ((match = re.exec(raw)) !== null) {
if (match.index > lastIndex) {
parts.push(raw.slice(lastIndex, match.index));
}
const id = match[1] ?? '';
let item = byId.get(id);
if (item) {
usedIds.add(item.id);
} else {
while (fallbackIndex < references.length && usedIds.has(references[fallbackIndex]!.id)) {
fallbackIndex += 1;
}
item = references[fallbackIndex];
fallbackIndex += 1;
if (item) usedIds.add(item.id);
}
if (item) {
parts.push(<PromptResourceTag item={item} references={references} key={`${match.index}-${item.id}`} />);
} else {
parts.push('@资产');
}
lastIndex = match.index + (match[0]?.length ?? 0);
}
if (lastIndex < raw.length) {
parts.push(raw.slice(lastIndex));
}
return parts;
}
function taskReferenceCardStyle(index: number) {
const valueIndex = index % taskReferenceTiltValues.length;
return {
'--reference-index': index,
'--reference-tilt': `${taskReferenceTiltValues[valueIndex]}deg`,
'--reference-y': `${taskReferenceYValues[valueIndex]}px`,
} as CSSProperties;
}
function mediaReferenceMentionLabel(item: PlaygroundUpload, references: PlaygroundUpload[]) {
return `@${promptReferenceKindLabel(item.kind)}${referenceKindIndex(item, references)}`;
}
function referenceKindIndex(item: PlaygroundUpload, references: PlaygroundUpload[]) {
const sameKind = references.filter((reference) => reference.kind === item.kind);
return Math.max(1, sameKind.findIndex((reference) => reference.id === item.id) + 1);
}
function mediaReferenceItems(run: MediaGenerationRun): PlaygroundUpload[] {
const uploads = normalizeReferenceUploads(run.uploads);
if (uploads.length) return uploads;
return referencesFromTaskRequest(run.task?.request, run.mode);
}
function normalizeReferenceUploads(value: unknown): PlaygroundUpload[] {
if (!Array.isArray(value)) return [];
return value
.map((item, index) => normalizeReferenceUpload(item, index))
.filter((item): item is PlaygroundUpload => Boolean(item));
}
function normalizeReferenceUpload(value: unknown, index: number): PlaygroundUpload | undefined {
const record = recordFromUnknown(value);
if (!record) return undefined;
const url = stringFromUnknown(record.url);
const kind = referenceKindFromUnknown(record.kind, url);
if (!url || !kind) return undefined;
const size = numberFromUnknown(record.size);
return {
contentType: stringFromUnknown(record.contentType),
id: stringFromUnknown(record.id) || `${kind}-${index}-${url}`,
kind,
name: stringFromUnknown(record.name) || `${referenceKindLabel(kind)} ${index + 1}`,
raw: recordFromUnknown(record.raw) ?? {},
role: record.role === 'first_frame' || record.role === 'last_frame' ? record.role : undefined,
size: size && size > 0 ? Math.round(size) : 0,
url,
};
}
function referencesFromTaskRequest(request: unknown, mode: Exclude<PlaygroundMode, 'chat'>): PlaygroundUpload[] {
const record = recordFromUnknown(request);
if (!record) return [];
const references: PlaygroundUpload[] = [];
if (Array.isArray(record.content)) {
record.content.forEach((item) => {
const content = recordFromUnknown(item);
if (!content) return;
appendReferenceFromContentPart(references, content);
});
}
if (mode === 'image') {
appendImageReferencesFromValue(references, record.image);
appendImageReferencesFromValue(references, record.images);
appendImageReferencesFromValue(references, record.input_image);
appendImageReferencesFromValue(references, record.input_images);
}
return dedupeReferenceUploads(references);
}
function appendReferenceFromContentPart(references: PlaygroundUpload[], part: Record<string, unknown>) {
const type = stringFromUnknown(part.type);
if (type === 'image_url') {
appendReferenceUrl(references, 'image', firstString(nestedString(part.image_url, 'url'), part.url));
return;
}
if (type === 'video_url') {
appendReferenceUrl(references, 'video', firstString(nestedString(part.video_url, 'url'), part.url));
return;
}
if (type === 'audio_url') {
appendReferenceUrl(references, 'audio', firstString(nestedString(part.audio_url, 'url'), part.url));
}
}
function appendImageReferencesFromValue(references: PlaygroundUpload[], value: unknown) {
if (!value) return;
if (typeof value === 'string') {
appendReferenceUrl(references, 'image', value);
return;
}
if (Array.isArray(value)) {
value.forEach((item) => appendImageReferencesFromValue(references, item));
return;
}
const record = recordFromUnknown(value);
if (!record) return;
appendReferenceUrl(references, 'image', firstString(record.url, nestedString(record.image_url, 'url'), record.image_url, record.imageUrl, record.path));
}
function appendReferenceUrl(references: PlaygroundUpload[], kind: PlaygroundUploadKind, rawUrl: unknown) {
const url = stringFromUnknown(rawUrl);
if (!url) return;
const resolvedUrl = resolveApiAssetUrl(url);
references.push({
contentType: '',
id: `${kind}-${references.length}-${resolvedUrl}`,
kind,
name: `${referenceKindLabel(kind)} ${references.filter((item) => item.kind === kind).length + 1}`,
raw: {},
size: 0,
url: resolvedUrl,
});
}
function dedupeReferenceUploads(references: PlaygroundUpload[]) {
const seen = new Set<string>();
return references.filter((item) => {
const key = `${item.kind}:${item.url}`;
if (seen.has(key)) return false;
seen.add(key);
return true;
});
}
function referenceKindFromUnknown(value: unknown, url: string): PlaygroundUploadKind | undefined {
if (value === 'image' || value === 'video' || value === 'audio' || value === 'file') return value;
if (/\.(png|jpe?g|webp|gif|bmp|avif|svg)(\?|#|$)/i.test(url)) return 'image';
if (/\.(mp4|mov|webm|m4v|avi|mkv)(\?|#|$)/i.test(url)) return 'video';
if (/\.(mp3|m4a|wav|aac|flac|ogg|opus)(\?|#|$)/i.test(url)) return 'audio';
return undefined;
}
function referenceKindLabel(kind: PlaygroundUploadKind) {
if (kind === 'image') return '图像';
if (kind === 'video') return '视频';
if (kind === 'audio') return '音频';
return '文件';
}
function promptReferenceKindLabel(kind: PlaygroundUploadKind) {
if (kind === 'image') return '图片';
return referenceKindLabel(kind);
}
function mediaResultItemFromEntry(entry: unknown, mode: Exclude<PlaygroundMode, 'chat'>): MediaResultItem | undefined {
const record = recordFromUnknown(entry);
if (!record) return undefined;
@ -662,8 +918,23 @@ function mediaStatusText(run: MediaGenerationRun) {
return run.status;
}
function formatRunTime(value: string) {
return new Intl.DateTimeFormat('zh-CN', { hour: '2-digit', minute: '2-digit' }).format(new Date(value));
function formatRunDateTime(value: string) {
const date = new Date(value);
if (Number.isNaN(date.getTime())) return value;
const pad = (item: number) => String(item).padStart(2, '0');
return `${date.getFullYear()}-${pad(date.getMonth() + 1)}-${pad(date.getDate())} ${pad(date.getHours())}:${pad(date.getMinutes())}:${pad(date.getSeconds())}`;
}
function mediaTaskMetaText(run: MediaGenerationRun) {
const items = [run.settings.aspectRatio, run.settings.resolution];
if (run.mode === 'video') {
items.push(`${run.settings.durationSeconds}s`);
if (run.settings.outputAudio) items.push('有声音');
} else {
const count = mediaOutputCount(run.settings);
if (count > 1) items.push(`${count}`);
}
return items.filter(Boolean).join(' | ');
}
function cssAspectRatio(settings: MediaGenerationSettings) {

View File

@ -0,0 +1,601 @@
import { useCallback, useEffect, useMemo, useRef, useState, type KeyboardEvent as ReactKeyboardEvent } from 'react';
import { FileText, Music2, Video } from 'lucide-react';
export type PlaygroundMentionUploadKind = 'audio' | 'file' | 'image' | 'video';
export interface PlaygroundMentionUpload {
id: string;
kind: PlaygroundMentionUploadKind;
name: string;
url: string;
}
const resourceTokenPrefix = '<<<playground-resource:';
const resourceTokenSuffix = '>>>';
const resourceTokenPattern = /<<<playground-resource:([^>]+)>>>/g;
const resourceTokenWithSpacePattern = /<<<playground-resource:([^>]+)>>>\s?/g;
export function buildPlaygroundResourceToken(id: string) {
return `${resourceTokenPrefix}${id}${resourceTokenSuffix}`;
}
export function removeInvalidPlaygroundResourceTokens(raw: string, uploads: PlaygroundMentionUpload[]) {
if (!raw.includes(resourceTokenPrefix)) return raw;
const validIds = new Set(uploads.map((item) => item.id));
return raw.replace(resourceTokenWithSpacePattern, (full, id: string) => validIds.has(id) ? full : '');
}
export function replacePlaygroundResourceTokens(
raw: string,
uploads: PlaygroundMentionUpload[],
mode: 'image' | 'video',
) {
if (!raw.includes(resourceTokenPrefix)) return raw;
const byId = new Map(uploads.map((item) => [item.id, item]));
return raw.replace(resourceTokenPattern, (full, id: string) => {
const item = byId.get(id);
if (!item) return '';
return resourcePromptLabel(item, uploads, mode);
});
}
export function replacePlaygroundResourceTokensForDisplay(raw: string, uploads: PlaygroundMentionUpload[]) {
if (!raw.includes(resourceTokenPrefix)) return raw;
const byId = new Map(uploads.map((item) => [item.id, item]));
const usedIds = new Set<string>();
let fallbackIndex = 0;
return raw.replace(resourceTokenPattern, (full, id: string) => {
let item = byId.get(id);
if (item) {
usedIds.add(item.id);
} else {
while (fallbackIndex < uploads.length && usedIds.has(uploads[fallbackIndex]!.id)) {
fallbackIndex += 1;
}
item = uploads[fallbackIndex];
fallbackIndex += 1;
if (item) usedIds.add(item.id);
}
if (!item) return '@资产';
return mentionDisplayLabel(item, uploads);
});
}
export function PlaygroundPromptMentionInput(props: {
disabled?: boolean;
placeholder: string;
uploads: PlaygroundMentionUpload[];
value: string;
onChange: (value: string) => void;
}) {
const editableRef = useRef<HTMLDivElement>(null);
const dropdownRef = useRef<HTMLDivElement>(null);
const blurTimerRef = useRef<number | undefined>(undefined);
const isComposingRef = useRef(false);
const [text, setText] = useState(props.value);
const [focused, setFocused] = useState(false);
const [hasEditableContent, setHasEditableContent] = useState(() => promptTextHasContent(props.value));
const [mentionOpen, setMentionOpen] = useState(false);
const [mentionAtIndex, setMentionAtIndex] = useState(-1);
const [mentionSearch, setMentionSearch] = useState('');
const [highlightIndex, setHighlightIndex] = useState(0);
const [dropdownPosition, setDropdownPosition] = useState({ top: 0, left: 0, placement: 'bottom' as 'bottom' | 'top' });
const showPlaceholder = !hasEditableContent;
const uploadSignature = useMemo(
() => props.uploads.map((item) => `${item.id}:${item.kind}:${item.name}:${item.url}`).join('|'),
[props.uploads],
);
const mentionItems = useMemo(() => props.uploads.map((item, index) => ({
item,
label: mentionDisplayLabel(item, props.uploads),
searchText: `${item.name} ${mentionDisplayLabel(item, props.uploads)} ${uploadKindChinese(item.kind)} 资产 素材 resource asset ${index + 1}`.toLowerCase(),
token: buildPlaygroundResourceToken(item.id),
})), [props.uploads]);
const filteredMentionItems = useMemo(() => {
const keyword = mentionSearch.trim().toLowerCase();
if (!keyword) return mentionItems;
return mentionItems.filter((item) => item.searchText.includes(keyword));
}, [mentionItems, mentionSearch]);
useEffect(() => {
if (props.value === text) {
setHasEditableContent(promptTextHasContent(props.value));
return;
}
setText(props.value);
setHasEditableContent(promptTextHasContent(props.value));
if (!isComposingRef.current) {
requestAnimationFrame(() => renderToEditable(props.value));
}
}, [focused, props.value, text]);
useEffect(() => {
const cleaned = removeInvalidPlaygroundResourceTokens(props.value, props.uploads);
if (cleaned !== props.value) {
setText(cleaned);
setHasEditableContent(promptTextHasContent(cleaned));
props.onChange(cleaned);
requestAnimationFrame(() => renderToEditable(cleaned));
return;
}
if (props.value !== text) {
setText(props.value);
setHasEditableContent(promptTextHasContent(props.value));
}
if (!focused) {
requestAnimationFrame(() => renderToEditable(props.value));
}
}, [uploadSignature]);
useEffect(() => {
requestAnimationFrame(() => renderToEditable(text));
return () => {
if (blurTimerRef.current) window.clearTimeout(blurTimerRef.current);
};
}, []);
const updateMentionDropdownPosition = useCallback(() => {
const editable = editableRef.current;
const root = editable?.parentElement;
if (!editable || !root) return;
const rootRect = root.getBoundingClientRect();
const caretRect = getCaretClientRect(editable) ?? editable.getBoundingClientRect();
const dropdown = dropdownRef.current;
const viewportPadding = 8;
const gap = 6;
const dropdownWidth = Math.min(dropdown?.offsetWidth ?? 320, Math.max(180, window.innerWidth - viewportPadding * 2));
const dropdownHeight = Math.min(dropdown?.offsetHeight ?? 220, Math.max(120, window.innerHeight - viewportPadding * 2));
const minLeft = viewportPadding - rootRect.left;
const maxLeft = window.innerWidth - viewportPadding - dropdownWidth - rootRect.left;
const left = clamp(caretRect.left - rootRect.left, minLeft, Math.max(minLeft, maxLeft));
const belowTop = caretRect.bottom + gap;
const aboveTop = caretRect.top - dropdownHeight - gap;
const shouldOpenAbove = belowTop + dropdownHeight > window.innerHeight - viewportPadding && aboveTop >= viewportPadding;
const viewportTop = shouldOpenAbove
? Math.max(viewportPadding, aboveTop)
: Math.min(belowTop, window.innerHeight - viewportPadding - dropdownHeight);
setDropdownPosition({
top: viewportTop - rootRect.top,
left,
placement: shouldOpenAbove ? 'top' : 'bottom',
});
}, []);
useEffect(() => {
if (!mentionOpen) return;
const frame = window.requestAnimationFrame(updateMentionDropdownPosition);
const handleViewportChange = () => updateMentionDropdownPosition();
window.addEventListener('resize', handleViewportChange);
window.addEventListener('scroll', handleViewportChange, true);
return () => {
window.cancelAnimationFrame(frame);
window.removeEventListener('resize', handleViewportChange);
window.removeEventListener('scroll', handleViewportChange, true);
};
}, [filteredMentionItems.length, mentionOpen, mentionSearch, updateMentionDropdownPosition]);
function renderToEditable(nextText = text, caret?: number) {
const editable = editableRef.current;
if (!editable) return;
editable.innerHTML = textToHtml(nextText, props.uploads);
setHasEditableContent(promptTextHasContent(nextText));
if (typeof caret === 'number') {
editable.focus();
setCaretOffset(editable, caret);
}
}
function commitFromEditable(shouldInspectMention: boolean, event?: InputEvent) {
const editable = editableRef.current;
if (!editable || props.disabled) return;
const nextText = serializeEditableToPlainText(editable);
setHasEditableContent(promptTextHasContent(nextText));
if (isComposingRef.current || event?.isComposing) return;
setText(nextText);
props.onChange(nextText);
if (!shouldInspectMention) return;
inspectMentionTrigger(nextText, getCaretOffset(editable), event);
}
function inspectMentionTrigger(nextText: string, caret: number, event?: InputEvent) {
const beforeCaret = nextText.slice(0, caret);
const atIndex = Math.max(beforeCaret.lastIndexOf('@'), beforeCaret.lastIndexOf(''));
if (atIndex < 0) {
closeMention();
return;
}
const search = beforeCaret.slice(atIndex + 1);
if (/\s/.test(search)) {
closeMention();
return;
}
const inputData = event?.data ?? '';
if (!mentionOpen && search.length > 0 && inputData !== '@' && inputData !== '') {
closeMention();
return;
}
setMentionAtIndex(atIndex);
setMentionSearch(search);
setMentionOpen(true);
setHighlightIndex(0);
requestAnimationFrame(updateMentionDropdownPosition);
}
function closeMention() {
setMentionOpen(false);
setMentionSearch('');
setMentionAtIndex(-1);
setHighlightIndex(0);
}
function applyMention(token: string) {
const editable = editableRef.current;
if (!editable || mentionAtIndex < 0) return;
const currentText = serializeEditableToPlainText(editable);
const replaceEnd = Math.max(mentionAtIndex + 1, mentionAtIndex + 1 + mentionSearch.length);
const before = currentText.slice(0, mentionAtIndex);
const after = currentText.slice(replaceEnd);
const nextText = `${before}${token} ${after}`;
const nextCaret = before.length + token.length + 1;
setText(nextText);
setHasEditableContent(promptTextHasContent(nextText));
props.onChange(nextText);
closeMention();
requestAnimationFrame(() => renderToEditable(nextText, nextCaret));
}
function handleKeyDown(event: ReactKeyboardEvent<HTMLDivElement>) {
if (props.disabled) {
event.preventDefault();
return;
}
if (!mentionOpen || !filteredMentionItems.length) return;
if (event.key === 'ArrowDown') {
event.preventDefault();
setHighlightIndex((current) => (current + 1) % filteredMentionItems.length);
return;
}
if (event.key === 'ArrowUp') {
event.preventDefault();
setHighlightIndex((current) => (current - 1 + filteredMentionItems.length) % filteredMentionItems.length);
return;
}
if (event.key === 'Enter') {
event.preventDefault();
const item = filteredMentionItems[Math.min(highlightIndex, filteredMentionItems.length - 1)];
if (item) applyMention(item.token);
return;
}
if (event.key === 'Escape') {
event.preventDefault();
closeMention();
}
}
return (
<div className="promptMentionInput">
<div
ref={editableRef}
className="promptMentionEditable"
contentEditable={!props.disabled}
role="textbox"
aria-label={props.placeholder}
aria-multiline="true"
suppressContentEditableWarning
onBlur={() => {
blurTimerRef.current = window.setTimeout(() => {
setFocused(false);
closeMention();
commitFromEditable(false);
}, 120);
}}
onCompositionEnd={() => {
isComposingRef.current = false;
setHasEditableContent(editablePromptTextHasContent(editableRef.current));
requestAnimationFrame(() => commitFromEditable(true));
}}
onCompositionStart={() => {
isComposingRef.current = true;
}}
onFocus={() => {
setFocused(true);
if (blurTimerRef.current) window.clearTimeout(blurTimerRef.current);
}}
onInput={(event) => commitFromEditable(true, event.nativeEvent instanceof InputEvent ? event.nativeEvent : undefined)}
onKeyDown={handleKeyDown}
onPaste={(event) => {
if (props.disabled) return;
event.preventDefault();
const plainText = event.clipboardData.getData('text/plain');
document.execCommand('insertText', false, plainText);
}}
/>
{showPlaceholder && <div className="promptMentionPlaceholder" aria-hidden="true">{props.placeholder}</div>}
{mentionOpen && (
<div
ref={dropdownRef}
className="promptMentionDropdown"
data-placement={dropdownPosition.placement}
style={{ left: dropdownPosition.left, top: dropdownPosition.top }}
onMouseDown={(event) => event.preventDefault()}
>
{filteredMentionItems.length ? (
filteredMentionItems.map((candidate, index) => (
<button
type="button"
className="promptMentionItem"
data-active={index === highlightIndex}
key={candidate.item.id}
onMouseEnter={() => setHighlightIndex(index)}
onMouseDown={(event) => {
event.preventDefault();
applyMention(candidate.token);
}}
>
<MentionThumb item={candidate.item} />
<span>
<strong>{candidate.label}</strong>
<small>{candidate.item.name}</small>
</span>
<em>{uploadKindChinese(candidate.item.kind)}</em>
</button>
))
) : (
<div className="promptMentionEmpty"></div>
)}
</div>
)}
</div>
);
}
function promptTextHasContent(raw: string) {
return raw.trim().length > 0;
}
function editablePromptTextHasContent(editable: HTMLElement | null) {
return editable ? promptTextHasContent(serializeEditableToPlainText(editable)) : false;
}
function MentionThumb(props: { item: PlaygroundMentionUpload }) {
if (props.item.kind === 'image') {
return <img className="promptMentionThumb" src={props.item.url} alt="" draggable={false} />;
}
if (props.item.kind === 'video') {
return <video className="promptMentionThumb" src={props.item.url} muted playsInline preload="metadata" />;
}
if (props.item.kind === 'audio') {
return <Music2 className="promptMentionThumbIcon" size={16} />;
}
return <FileText className="promptMentionThumbIcon" size={16} />;
}
function textToHtml(raw: string, uploads: PlaygroundMentionUpload[]) {
const uploadById = new Map(uploads.map((item) => [item.id, item]));
const parts: string[] = [];
let lastIndex = 0;
let match: RegExpExecArray | null;
const re = new RegExp(resourceTokenPattern);
while ((match = re.exec(raw)) !== null) {
if (match.index > lastIndex) {
parts.push(escapeHtml(raw.slice(lastIndex, match.index)));
}
const token = match[0] ?? '';
const id = match[1] ?? '';
const item = uploadById.get(id);
parts.push(item ? mentionChipHtml(token, item, uploads) : '');
lastIndex = match.index + token.length;
}
if (lastIndex < raw.length) {
parts.push(escapeHtml(raw.slice(lastIndex)));
}
return parts.join('').replace(/\n/g, '<br>');
}
function mentionChipHtml(token: string, item: PlaygroundMentionUpload, uploads: PlaygroundMentionUpload[]) {
const label = mentionDisplayLabel(item, uploads);
const thumb = item.kind === 'image'
? `<img class="promptMentionChipThumb" src="${escapeAttr(item.url)}" alt="" draggable="false">`
: item.kind === 'video'
? `<video class="promptMentionChipThumb" src="${escapeAttr(item.url)}" muted preload="metadata" playsinline></video>`
: `<span class="promptMentionChipThumb promptMentionChipThumbPlaceholder">${escapeHtml(uploadKindShort(item.kind))}</span>`;
return `<span contenteditable="false" class="promptMentionChip" data-token="${escapeAttr(token)}">${thumb}<span>${escapeHtml(label)}</span></span>`;
}
function getCaretClientRect(editable: HTMLElement) {
const selection = document.getSelection();
if (!selection || selection.rangeCount === 0) return null;
const originalRange = selection.getRangeAt(0);
if (!editable.contains(originalRange.startContainer)) return null;
const collapsedRange = originalRange.cloneRange();
collapsedRange.collapse(true);
let rect = collapsedRange.getBoundingClientRect();
if (rect.width || rect.height) return rect;
const marker = document.createElement('span');
marker.textContent = '\u200b';
const restoreRange = originalRange.cloneRange();
collapsedRange.insertNode(marker);
rect = marker.getBoundingClientRect();
marker.remove();
selection.removeAllRanges();
selection.addRange(restoreRange);
return rect.width || rect.height ? rect : null;
}
function serializeEditableToPlainText(editable: HTMLElement) {
const parts: string[] = [];
const walk = (node: Node) => {
if (node.nodeType === Node.TEXT_NODE) {
parts.push(node.textContent ?? '');
return;
}
if (node.nodeType !== Node.ELEMENT_NODE) return;
const element = node as HTMLElement;
const token = element.getAttribute('data-token');
if (token) {
parts.push(token);
return;
}
if (element.tagName === 'BR') {
parts.push('\n');
return;
}
if (element.tagName === 'DIV' && parts.length && !parts[parts.length - 1]?.endsWith('\n')) {
parts.push('\n');
}
element.childNodes.forEach(walk);
};
editable.childNodes.forEach(walk);
return parts.join('');
}
function getCaretOffset(editable: HTMLElement) {
const selection = document.getSelection();
if (!selection || selection.rangeCount === 0) return 0;
const range = selection.getRangeAt(0);
let offset = 0;
const walk = (node: Node): boolean => {
if (node.nodeType === Node.TEXT_NODE) {
const textLength = (node.textContent ?? '').length;
if (node === range.startContainer) {
offset += range.startOffset;
return true;
}
offset += textLength;
return false;
}
if (node.nodeType !== Node.ELEMENT_NODE) return false;
const element = node as HTMLElement;
const token = element.getAttribute('data-token');
if (token) {
if (node === range.startContainer || element.contains(range.startContainer)) {
offset += range.startOffset > 0 ? token.length : 0;
return true;
}
offset += token.length;
return false;
}
if (element.tagName === 'BR') {
offset += 1;
return false;
}
for (const child of Array.from(element.childNodes)) {
if (walk(child)) return true;
}
return false;
};
for (const child of Array.from(editable.childNodes)) {
if (walk(child)) break;
}
return offset;
}
function setCaretOffset(editable: HTMLElement, targetOffset: number) {
const selection = document.getSelection();
if (!selection) return;
let offset = 0;
let targetNode: Node | null = null;
let targetNodeOffset = 0;
const walk = (node: Node): boolean => {
if (node.nodeType === Node.TEXT_NODE) {
const textLength = (node.textContent ?? '').length;
if (offset + textLength >= targetOffset) {
targetNode = node;
targetNodeOffset = Math.max(0, targetOffset - offset);
return true;
}
offset += textLength;
return false;
}
if (node.nodeType !== Node.ELEMENT_NODE) return false;
const element = node as HTMLElement;
const token = element.getAttribute('data-token');
if (token) {
if (offset + token.length >= targetOffset) {
targetNode = node;
targetNodeOffset = targetOffset <= offset ? 0 : 1;
return true;
}
offset += token.length;
return false;
}
if (element.tagName === 'BR') {
if (offset + 1 >= targetOffset) {
targetNode = node;
targetNodeOffset = 0;
return true;
}
offset += 1;
return false;
}
for (const child of Array.from(element.childNodes)) {
if (walk(child)) return true;
}
return false;
};
for (const child of Array.from(editable.childNodes)) {
if (walk(child)) break;
}
if (!targetNode) {
targetNode = editable;
targetNodeOffset = editable.childNodes.length;
}
const range = document.createRange();
range.setStart(targetNode, targetNode.nodeType === Node.TEXT_NODE ? targetNodeOffset : Math.min(targetNodeOffset, targetNode.childNodes.length));
range.collapse(true);
selection.removeAllRanges();
selection.addRange(range);
}
function mentionDisplayLabel(item: PlaygroundMentionUpload, uploads: PlaygroundMentionUpload[]) {
return `@${uploadKindChinese(item.kind)} ${uploadKindIndex(item, uploads)}`;
}
function resourcePromptLabel(item: PlaygroundMentionUpload, uploads: PlaygroundMentionUpload[], mode: 'image' | 'video') {
const kind = mode === 'image' ? 'image' : uploadKindEnglish(item.kind);
return `${kind} ${uploadKindIndex(item, uploads)}`;
}
function uploadKindIndex(item: PlaygroundMentionUpload, uploads: PlaygroundMentionUpload[]) {
const sameKind = uploads.filter((upload) => upload.kind === item.kind);
return Math.max(1, sameKind.findIndex((upload) => upload.id === item.id) + 1);
}
function uploadKindChinese(kind: PlaygroundMentionUploadKind) {
if (kind === 'image') return '图像';
if (kind === 'video') return '视频';
if (kind === 'audio') return '音频';
return '文件';
}
function uploadKindEnglish(kind: PlaygroundMentionUploadKind) {
if (kind === 'image') return 'image';
if (kind === 'video') return 'video';
if (kind === 'audio') return 'audio';
return 'file';
}
function uploadKindShort(kind: PlaygroundMentionUploadKind) {
if (kind === 'image') return '图';
if (kind === 'video') return '视';
if (kind === 'audio') return '音';
return '文';
}
function clamp(value: number, min: number, max: number) {
return Math.min(max, Math.max(min, value));
}
function escapeHtml(value: string) {
return value
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;');
}
function escapeAttr(value: string) {
return escapeHtml(value).replace(/"/g, '&quot;');
}

View File

@ -0,0 +1,120 @@
import type { ReactNode } from 'react';
import type { GatewayApiKey, PlatformModel } from '@easyai-ai-gateway/contracts';
import { Bot, Image as ImageIcon, Video } from 'lucide-react';
import { Select } from '../components/ui';
import type { PlaygroundMode } from '../types';
import type { PlaygroundVideoCreateMode } from './playground-upload';
export type VideoCreateMode = PlaygroundVideoCreateMode;
export interface ModelOption {
count: number;
label: string;
models: PlatformModel[];
provider: string;
value: string;
}
export const modeOptions: Array<{ description: string; icon: ReactNode; label: string; value: PlaygroundMode }> = [
{ value: 'chat', label: '大模型对话', description: '对话、推理、结构化输出', icon: <Bot size={16} /> },
{ value: 'image', label: '图像生成', description: '文生图、图像编辑参数预览', icon: <ImageIcon size={16} /> },
{ value: 'video', label: '视频生成', description: '图生视频、文生视频任务测试', icon: <Video size={16} /> },
];
export const videoModeOptions: Array<{ label: string; value: VideoCreateMode }> = [
{ value: 'text_to_video', label: '文生视频' },
{ value: 'first_last_frame', label: '首尾帧' },
{ value: 'omni_reference', label: '全能参考' },
];
export const placeholderByMode: Record<PlaygroundMode, string> = {
chat: '输入问题、角色设定或测试提示词,支持 OpenAI 兼容格式验证...',
image: '描述你想生成的画面,例如:未来城市中的玻璃温室,晨光,电影级构图...',
video: '描述视频镜头、主体运动和风格,例如:低角度跟拍,一辆复古跑车穿过雨夜街道...',
};
export const quickPrompts: Record<PlaygroundMode, string[]> = {
chat: ['写一个产品发布摘要', '生成接口调用示例', '分析失败重试策略'],
image: ['产品海报', '角色设定图', '电商主图'],
video: ['5 秒运镜', '首帧转视频', '宣传短片'],
};
export function resolveSelectedApiKeyId(apiKeys: GatewayApiKey[], secretsById: Record<string, string>, selectedApiKeyId: string) {
if (selectedApiKeyId && secretsById[selectedApiKeyId]) return selectedApiKeyId;
const firstUsable = apiKeys.find((item) => Boolean(secretsById[item.id]));
return firstUsable?.id ?? '';
}
export function apiKeyNoticeText(apiKeys: GatewayApiKey[], secretsById: Record<string, string>) {
if (!apiKeys.length) return '当前账号还没有可用 API Key请先创建一个 Key。';
if (!apiKeys.some((item) => Boolean(secretsById[item.id]))) {
return '当前没有可用于在线测试的完整 API Key请重新加载或创建一个 Key。';
}
return '';
}
export function modelOptionLabel(option: ModelOption) {
const count = option.count > 1 ? ` · ${option.count} 个客户端` : '';
const provider = option.provider ? ` · ${option.provider}` : '';
return `${option.label}${provider}${count}`;
}
export function ApiKeySelect(props: {
apiKeySecretsById: Record<string, string>;
apiKeys: GatewayApiKey[];
selectedApiKeyId: string;
onApiKeyChange: (apiKeyId: string) => void;
}) {
const activeApiKeyId = resolveSelectedApiKeyId(props.apiKeys, props.apiKeySecretsById, props.selectedApiKeyId);
return (
<Select
className="playgroundApiKeySelect"
value={activeApiKeyId}
disabled={!props.apiKeys.length}
onChange={(event) => props.onApiKeyChange(event.target.value)}
>
{!activeApiKeyId && <option value="">{props.apiKeys.length ? '选择 API Key' : '暂无 API Key'}</option>}
{props.apiKeys.map((item) => {
const usable = Boolean(props.apiKeySecretsById[item.id]);
return (
<option value={item.id} key={item.id} disabled={!usable}>
{item.name} · {item.keyPrefix}
</option>
);
})}
</Select>
);
}
export function ModeSwitch(props: {
activeMode: PlaygroundMode;
onModeChange: (mode: PlaygroundMode) => void;
}) {
return (
<div className="playgroundModeSwitch">
{modeOptions.map((item) => (
<button
type="button"
key={item.value}
data-active={props.activeMode === item.value}
onClick={() => props.onModeChange(item.value)}
>
{item.icon}
<span>{item.label}</span>
</button>
))}
</div>
);
}
export function PlaygroundGreeting(props: {
activeMode: { description: string; label: string };
}) {
return (
<div className="playgroundGreeting">
<span></span>
<strong>{props.activeMode.label}</strong>
<small>{props.activeMode.description}</small>
</div>
);
}

View File

@ -0,0 +1,742 @@
import { useRef, useState, type CSSProperties } from 'react';
import {
FileText,
Image as ImageIcon,
LoaderCircle,
Music2,
Paperclip,
Plus,
Repeat2,
Video,
X,
} from 'lucide-react';
import { uploadFileToStorage } from '../api';
import type { VideoGenerationContent } from '../api';
import type { PlaygroundMode } from '../types';
export type PlaygroundUploadKind = 'audio' | 'file' | 'image' | 'video';
export type PlaygroundUploadRole = 'first_frame' | 'last_frame';
export type PlaygroundVideoCreateMode = 'text_to_video' | 'first_last_frame' | 'omni_reference';
export interface PlaygroundUpload {
contentType: string;
id: string;
kind: PlaygroundUploadKind;
name: string;
raw: Record<string, unknown>;
role?: PlaygroundUploadRole;
size: number;
url: string;
}
export type OpenAIChatContentPart =
| { type: 'text'; text: string }
| { type: 'image_url'; image_url: { url: string } }
| { type: 'video_url'; video_url: { url: string } }
| { type: 'audio_url'; audio_url: { url: string } }
| { type: 'file_url'; file_url: { filename: string; url: string } };
export const mediaUploadAccept = 'image/*,video/*,audio/*';
export const imageOnlyUploadAccept = 'image/*';
export const chatUploadAccept = [
mediaUploadAccept,
'.csv',
'.doc',
'.docx',
'.json',
'.jsonl',
'.md',
'.markdown',
'.pdf',
'.ppt',
'.pptx',
'.txt',
'.xls',
'.xlsx',
'.yaml',
'.yml',
'application/json',
'application/msword',
'application/pdf',
'application/vnd.ms-excel',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'text/*',
].join(',');
export function ComposerUploadButton(props: {
accept: string;
active?: boolean;
disabled?: boolean;
uploading?: boolean;
onFiles?: (files: File[]) => void;
}) {
const inputRef = useRef<HTMLInputElement>(null);
const disabled = props.disabled || props.uploading;
return (
<>
<button
type="button"
className="composerUpload"
aria-label="上传附件"
data-active={props.active === true}
disabled={disabled}
onClick={() => inputRef.current?.click()}
>
{props.uploading ? <LoaderCircle className="composerUploadSpinner" size={18} /> : <Paperclip size={18} />}
</button>
<input
ref={inputRef}
type="file"
multiple
hidden
accept={props.accept}
disabled={disabled}
onChange={(event) => {
const files = Array.from(event.currentTarget.files ?? []);
event.currentTarget.value = '';
props.onFiles?.(files);
}}
/>
</>
);
}
export function PlaygroundReferencePicker(props: {
accept: string;
disabled?: boolean;
mode: PlaygroundMode;
uploadLabel?: string;
uploads: PlaygroundUpload[];
uploading?: boolean;
videoMode?: PlaygroundVideoCreateMode;
onFiles?: (files: File[], targetRole?: PlaygroundUploadRole) => void;
onRemove?: (id: string) => void;
onSwapFrames?: () => void;
}) {
if (props.mode === 'video' && props.videoMode === 'first_last_frame') {
return (
<FirstLastFramePicker
accept={props.accept}
disabled={props.disabled}
uploads={props.uploads}
uploading={props.uploading}
onFiles={props.onFiles}
onRemove={props.onRemove}
onSwapFrames={props.onSwapFrames}
/>
);
}
return (
<StackedReferencePicker
accept={props.accept}
disabled={props.disabled}
uploadLabel={props.uploadLabel ?? (props.mode === 'chat' ? '上传附件' : '参考内容')}
uploads={props.uploads}
uploading={props.uploading}
onFiles={props.onFiles}
onRemove={props.onRemove}
/>
);
}
function StackedReferencePicker(props: {
accept: string;
disabled?: boolean;
uploadLabel: string;
uploads: PlaygroundUpload[];
uploading?: boolean;
onFiles?: (files: File[]) => void;
onRemove?: (id: string) => void;
}) {
const inputRef = useRef<HTMLInputElement>(null);
const [hoveredId, setHoveredId] = useState('');
const [expanded, setExpanded] = useState(false);
const hoveredUpload = props.uploads.find((item) => item.id === hoveredId);
const disabled = props.disabled || props.uploading || !props.onFiles;
const uploadCardIndex = props.uploads.length;
return (
<div
className="mediaReferenceStack"
data-expanded={expanded}
onMouseLeave={() => {
setExpanded(false);
setHoveredId('');
}}
>
{hoveredUpload && <div className="mediaReferenceTooltip">{hoveredUpload.name}</div>}
<div
className="mediaReferenceStackCards"
data-empty={!props.uploads.length}
style={{ '--reference-count': Math.max(1, props.uploads.length + 1) } as CSSProperties}
>
{props.uploads.map((item, index) => (
<div
className="mediaReferenceCard"
data-hovered={hoveredId === item.id}
data-kind={item.kind}
key={item.id}
style={referenceCardStyle(index)}
title={item.name}
onMouseEnter={() => {
setExpanded(true);
setHoveredId(item.id);
}}
>
<ReferencePreview item={item} />
{item.kind !== 'image' && <span className="mediaReferenceDuration">{uploadKindLabel(item.kind)}</span>}
{props.onRemove && hoveredId === item.id && (
<button type="button" className="mediaReferenceRemove" aria-label={`删除 ${item.name}`} onClick={() => props.onRemove?.(item.id)}>
<X size={13} />
</button>
)}
</div>
))}
<button
type="button"
className="mediaReferenceCard mediaReferenceUploadCard"
aria-label={props.uploadLabel}
data-has-uploads={props.uploads.length > 0}
data-uploading={Boolean(props.uploading)}
disabled={disabled}
style={referenceCardStyle(uploadCardIndex)}
title={props.uploadLabel}
onClick={() => inputRef.current?.click()}
onMouseEnter={() => setHoveredId('')}
>
{props.uploading ? <LoaderCircle className="composerUploadSpinner" size={18} /> : <Plus size={20} />}
<span>{props.uploadLabel}</span>
</button>
{props.uploads.length > 0 && (
<button
type="button"
className="mediaReferenceAdd"
aria-label={props.uploadLabel}
data-uploading={Boolean(props.uploading)}
disabled={disabled}
title={props.uploadLabel}
onClick={() => inputRef.current?.click()}
onMouseEnter={() => {
setExpanded(false);
setHoveredId('');
}}
>
{props.uploading ? <LoaderCircle className="composerUploadSpinner" size={15} /> : <Plus size={17} />}
</button>
)}
</div>
<input
ref={inputRef}
type="file"
multiple
hidden
accept={props.accept}
disabled={disabled}
onChange={(event) => {
const files = Array.from(event.currentTarget.files ?? []);
event.currentTarget.value = '';
props.onFiles?.(files);
}}
/>
</div>
);
}
const referenceTiltValues = [-7, 6, -3, 8, -5, 4, -8, 5];
const referenceXValues = [0, -5, 6, -3, 4, -6, 3, -4];
const referenceYValues = [0, 3, -1, 4, 1, 5, 2, -2];
function referenceCardStyle(index: number) {
const valueIndex = index % referenceTiltValues.length;
return {
'--reference-index': index,
'--reference-tilt': `${referenceTiltValues[valueIndex]}deg`,
'--reference-x': `${referenceXValues[valueIndex]}px`,
'--reference-y': `${referenceYValues[valueIndex]}px`,
} as CSSProperties;
}
function FirstLastFramePicker(props: {
accept: string;
disabled?: boolean;
uploads: PlaygroundUpload[];
uploading?: boolean;
onFiles?: (files: File[], targetRole?: PlaygroundUploadRole) => void;
onRemove?: (id: string) => void;
onSwapFrames?: () => void;
}) {
const first = frameUploadByRole(props.uploads, 'first_frame');
const last = frameUploadByRole(props.uploads, 'last_frame');
const canSwap = Boolean(first && last);
return (
<div className="firstLastFramePicker">
<FrameSlot
accept={props.accept}
disabled={props.disabled}
item={first}
label="首帧"
role="first_frame"
uploading={props.uploading}
onFiles={props.onFiles}
onRemove={props.onRemove}
/>
<button type="button" className="frameSwapButton" aria-label="交换首尾帧" disabled={!canSwap} onClick={props.onSwapFrames}>
<Repeat2 size={19} />
</button>
<FrameSlot
accept={props.accept}
disabled={props.disabled}
item={last}
label="尾帧"
role="last_frame"
uploading={props.uploading}
onFiles={props.onFiles}
onRemove={props.onRemove}
/>
</div>
);
}
function FrameSlot(props: {
accept: string;
disabled?: boolean;
item?: PlaygroundUpload;
label: string;
role: PlaygroundUploadRole;
uploading?: boolean;
onFiles?: (files: File[], targetRole?: PlaygroundUploadRole) => void;
onRemove?: (id: string) => void;
}) {
const inputRef = useRef<HTMLInputElement>(null);
const disabled = props.disabled || props.uploading || !props.onFiles;
return (
<div className="frameSlot">
<button
type="button"
className="frameSlotButton"
data-filled={Boolean(props.item)}
disabled={disabled}
title={props.item?.name ?? props.label}
onClick={() => inputRef.current?.click()}
>
{props.item ? <ReferencePreview item={props.item} /> : <Plus size={20} />}
<span>{props.label}</span>
</button>
{props.item && props.onRemove && (
<button type="button" className="frameSlotRemove" aria-label={`删除 ${props.label}`} onClick={() => props.onRemove?.(props.item!.id)}>
<X size={13} />
</button>
)}
<input
ref={inputRef}
type="file"
hidden
accept={props.accept}
disabled={disabled}
onChange={(event) => {
const files = Array.from(event.currentTarget.files ?? []);
event.currentTarget.value = '';
props.onFiles?.(files, props.role);
}}
/>
</div>
);
}
function ReferencePreview(props: { item: PlaygroundUpload }) {
if (props.item.kind === 'image') {
return <img src={props.item.url} alt="" draggable={false} />;
}
if (props.item.kind === 'video') {
return <video src={props.item.url} muted playsInline preload="metadata" />;
}
if (props.item.kind === 'audio') {
return (
<span className="mediaReferencePlaceholder">
<Music2 size={18} />
</span>
);
}
return (
<span className="mediaReferencePlaceholder">
<FileText size={18} />
</span>
);
}
export function UploadAttachmentList(props: {
message?: string;
uploads: PlaygroundUpload[];
onRemove?: (id: string) => void;
}) {
if (!props.uploads.length && !props.message) return null;
return (
<div className="composerUploadArea">
{props.uploads.length > 0 && (
<div className="composerUploadList">
{props.uploads.map((item) => (
<span className="composerUploadChip" key={item.id} title={`${item.name} · ${item.url}`}>
{uploadKindIcon(item.kind)}
<span>{item.name}</span>
<small>{formatFileSize(item.size)}</small>
{props.onRemove && (
<button type="button" aria-label={`移除 ${item.name}`} onClick={() => props.onRemove?.(item.id)}>
<X size={13} />
</button>
)}
</span>
))}
</div>
)}
{props.message && <div className="composerUploadMessage">{props.message}</div>}
</div>
);
}
function uploadKindIcon(kind: PlaygroundUploadKind) {
if (kind === 'image') return <ImageIcon size={14} />;
if (kind === 'video') return <Video size={14} />;
if (kind === 'audio') return <Music2 size={14} />;
return <FileText size={14} />;
}
export async function uploadPlaygroundFiles(
token: string,
files: File[],
options: { allowFiles: boolean; allowedKinds?: PlaygroundUploadKind[]; source: string },
): Promise<{ items: PlaygroundUpload[]; warnings: string[] }> {
const allowedKinds = options.allowedKinds ?? (options.allowFiles ? ['audio', 'file', 'image', 'video'] : ['audio', 'image', 'video']);
const accepted: Array<{ file: File; kind: PlaygroundUploadKind }> = [];
const warnings: string[] = [];
files.forEach((file) => {
const kind = acceptedUploadKind(file, options.allowFiles);
if (!kind || !allowedKinds.includes(kind)) {
warnings.push(options.allowFiles
? `已跳过 ${file.name},聊天仅支持图片、视频、音频和常见文档。`
: `已跳过 ${file.name},当前场景仅支持${allowedUploadKindLabel(allowedKinds)}`);
return;
}
accepted.push({ file, kind });
});
if (!accepted.length) return { items: [], warnings };
const items = await Promise.all(accepted.map(async ({ file, kind }) => {
const response = await uploadFileToStorage(token, file, options.source);
const url = uploadResponseUrl(response);
if (!url) {
throw new Error(`${file.name} 上传成功,但网关没有返回可用文件 URL。`);
}
return {
contentType: file.type || '',
id: newLocalId(),
kind,
name: file.name || '未命名文件',
raw: response,
size: file.size,
url,
};
}));
return { items, warnings };
}
function acceptedUploadKind(file: File, allowFiles: boolean): PlaygroundUploadKind | undefined {
const mime = file.type.toLowerCase();
const extension = fileExtension(file.name);
if (mime.startsWith('image/') || imageExtensions.has(extension)) return 'image';
if (mime.startsWith('video/') || videoExtensions.has(extension)) return 'video';
if (mime.startsWith('audio/') || audioExtensions.has(extension)) return 'audio';
if (allowFiles && (documentExtensions.has(extension) || documentMimes.has(mime))) return 'file';
return undefined;
}
const imageExtensions = new Set(['avif', 'bmp', 'gif', 'heic', 'heif', 'jpeg', 'jpg', 'png', 'svg', 'tif', 'tiff', 'webp']);
const videoExtensions = new Set(['avi', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'webm']);
const audioExtensions = new Set(['aac', 'flac', 'm4a', 'mp3', 'oga', 'ogg', 'opus', 'wav', 'weba']);
const documentExtensions = new Set(['csv', 'doc', 'docx', 'json', 'jsonl', 'md', 'markdown', 'pdf', 'ppt', 'pptx', 'txt', 'xls', 'xlsx', 'yaml', 'yml']);
const documentMimes = new Set([
'application/json',
'application/msword',
'application/pdf',
'application/vnd.ms-excel',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'text/csv',
'text/markdown',
'text/plain',
'text/yaml',
]);
function fileExtension(name: string) {
const index = name.lastIndexOf('.');
return index >= 0 ? name.slice(index + 1).toLowerCase() : '';
}
function uploadResponseUrl(response: Record<string, unknown>) {
const data = recordFromUnknown(response.data);
const file = recordFromUnknown(response.file);
const result = recordFromUnknown(response.result);
return firstString(
response.url,
response.fileUrl,
response.file_url,
response.objectUrl,
response.object_url,
response.downloadUrl,
response.download_url,
data?.url,
data?.fileUrl,
data?.file_url,
file?.url,
file?.fileUrl,
file?.file_url,
result?.url,
result?.fileUrl,
result?.file_url,
);
}
export function openAIContentFromPromptAndUploads(prompt: string, uploads: PlaygroundUpload[]): OpenAIChatContentPart[] {
const content: OpenAIChatContentPart[] = [];
const text = prompt.trim();
if (text) {
content.push({ type: 'text', text });
}
uploads.forEach((item) => {
const part = openAIContentPartFromUpload(item);
if (part) content.push(part);
});
return content.length ? content : [{ type: 'text', text: '' }];
}
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
if (!item.url) return undefined;
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
if (item.kind === 'video') return { type: 'video_url', video_url: { url: item.url } };
if (item.kind === 'audio') return { type: 'audio_url', audio_url: { url: item.url } };
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
}
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
const payload: Record<string, string | string[]> = {};
if (mode === 'image') {
if (images.length) {
payload.image = singleOrMany(images);
payload.images = images;
}
return payload;
}
return payload;
}
export function videoGenerationContentFromPromptAndUploads(
prompt: string,
uploads: PlaygroundUpload[],
videoMode: PlaygroundVideoCreateMode,
): VideoGenerationContent[] {
const content: VideoGenerationContent[] = [];
const text = prompt.trim();
if (text) {
content.push({ type: 'text', text });
}
if (videoMode === 'first_last_frame') {
const first = frameUploadByRole(uploads, 'first_frame');
const last = frameUploadByRole(uploads, 'last_frame');
if (first?.url) {
content.push({ type: 'image_url', role: 'first_frame', image_url: { url: first.url } });
}
if (last?.url) {
content.push({ type: 'image_url', role: 'last_frame', image_url: { url: last.url } });
}
return content.length ? content : [{ type: 'text', text: '' }];
}
uploads.forEach((item) => {
const part = videoGenerationContentFromUpload(item);
if (part) content.push(part);
});
return content.length ? content : [{ type: 'text', text: '' }];
}
function videoGenerationContentFromUpload(item: PlaygroundUpload): VideoGenerationContent | undefined {
if (!item.url) return undefined;
if (item.kind === 'image') {
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
}
if (item.kind === 'video') {
return { type: 'video_url', role: 'reference_video', video_url: { url: item.url, refer_type: 'feature' } };
}
if (item.kind === 'audio') {
return { type: 'audio_url', role: 'reference_audio', audio_url: { url: item.url } };
}
return undefined;
}
function singleOrMany(values: string[]) {
return values.length === 1 ? values[0] : values;
}
export function uploadKindLabel(kind: PlaygroundUploadKind) {
if (kind === 'image') return '图片';
if (kind === 'video') return '视频';
if (kind === 'audio') return '音频';
return '文件';
}
export function allowedUploadKindLabel(kinds: PlaygroundUploadKind[]) {
const labels = kinds.map(uploadKindLabel);
return labels.length ? labels.join('、') : '当前文件类型';
}
export function formatFileSize(size: number) {
if (!Number.isFinite(size) || size <= 0) return '';
if (size < 1024) return `${size} B`;
if (size < 1024 * 1024) return `${(size / 1024).toFixed(1)} KB`;
return `${(size / 1024 / 1024).toFixed(1)} MB`;
}
export function mediaUploadAcceptForMode(mode: PlaygroundMode, videoMode: PlaygroundVideoCreateMode) {
if (mode === 'image') return imageOnlyUploadAccept;
if (mode === 'video' && videoMode === 'first_last_frame') return imageOnlyUploadAccept;
return mediaUploadAccept;
}
export function allowedMediaUploadKinds(mode: PlaygroundMode, videoMode: PlaygroundVideoCreateMode): PlaygroundUploadKind[] {
if (mode === 'image') return ['image'];
if (mode === 'video' && videoMode === 'first_last_frame') return ['image'];
if (mode === 'video') return ['audio', 'image', 'video'];
return ['audio', 'file', 'image', 'video'];
}
export function mediaUploadSummaryMessage(uploads: PlaygroundUpload[], mode: PlaygroundMode, videoMode: PlaygroundVideoCreateMode) {
if (!uploads.length) return '';
const images = uploads.filter((item) => item.kind === 'image').length;
const videos = uploads.filter((item) => item.kind === 'video').length;
const audios = uploads.filter((item) => item.kind === 'audio').length;
const files = uploads.filter((item) => item.kind === 'file').length;
if (mode === 'image') {
return `已上传 ${images} 张参考图。`;
}
if (mode === 'video' && videoMode === 'first_last_frame') {
const first = frameUploadByRole(uploads, 'first_frame');
const last = frameUploadByRole(uploads, 'last_frame');
if (first && last) return '已上传首帧、尾帧参考图。';
if (first) return '已上传首帧参考图。';
if (last) return '已上传尾帧参考图。';
return `已上传 ${images} 张首尾帧参考图。`;
}
const parts = [
images ? `${images} 张图片` : '',
videos ? `${videos} 个视频` : '',
audios ? `${audios} 段音频` : '',
files ? `${files} 个文件` : '',
].filter(Boolean);
return parts.length ? `已上传 ${parts.join('、')}` : '';
}
export function mergeMediaUploadsForMode(
current: PlaygroundUpload[],
items: PlaygroundUpload[],
mode: PlaygroundMode,
videoMode: PlaygroundVideoCreateMode,
targetRole?: PlaygroundUploadRole,
) {
if (mode === 'image') {
return [...current.filter((item) => item.kind === 'image'), ...items.filter((item) => item.kind === 'image')];
}
if (mode === 'video' && videoMode === 'first_last_frame') {
return mergeFirstLastFrameUploads(current, items, targetRole);
}
if (mode === 'video') {
return [...current, ...items.filter((item) => item.kind === 'image' || item.kind === 'video' || item.kind === 'audio')];
}
return [...current, ...items];
}
export function normalizeFirstLastFrameUploads(uploads: PlaygroundUpload[]) {
const images = uploads.filter((item) => item.kind === 'image');
if (!images.length) return uploads.length ? [] : uploads;
const first = frameUploadByRole(images, 'first_frame') ?? images[0];
const last = frameUploadByRole(images, 'last_frame') ?? images.find((item) => item.id !== first?.id);
const next: PlaygroundUpload[] = [];
if (first) next.push({ ...first, role: 'first_frame' });
if (last) next.push({ ...last, role: 'last_frame' });
return uploadListsEqual(uploads, next) ? uploads : next;
}
function mergeFirstLastFrameUploads(current: PlaygroundUpload[], items: PlaygroundUpload[], targetRole?: PlaygroundUploadRole) {
const incoming = items.filter((item) => item.kind === 'image');
let next = normalizeFirstLastFrameUploads(current);
if (!incoming.length) return next;
const assignUpload = (item: PlaygroundUpload, role: PlaygroundUploadRole) => {
next = next.filter((upload) => upload.role !== role);
next.push({ ...item, role });
};
if (targetRole) {
assignUpload(incoming[0]!, targetRole);
const oppositeRole: PlaygroundUploadRole = targetRole === 'first_frame' ? 'last_frame' : 'first_frame';
incoming.slice(1).forEach((item) => {
if (!frameUploadByRole(next, oppositeRole)) assignUpload(item, oppositeRole);
});
return sortFrameUploads(next);
}
incoming.forEach((item) => {
if (!frameUploadByRole(next, 'first_frame')) {
assignUpload(item, 'first_frame');
} else if (!frameUploadByRole(next, 'last_frame')) {
assignUpload(item, 'last_frame');
}
});
return sortFrameUploads(next);
}
export function swapFirstLastFrameUploads(uploads: PlaygroundUpload[]) {
return sortFrameUploads(uploads.map((item) => {
if (item.role === 'first_frame') return { ...item, role: 'last_frame' as const };
if (item.role === 'last_frame') return { ...item, role: 'first_frame' as const };
return item;
}));
}
function sortFrameUploads(uploads: PlaygroundUpload[]) {
const first = frameUploadByRole(uploads, 'first_frame');
const last = frameUploadByRole(uploads, 'last_frame');
return [first, last].filter((item): item is PlaygroundUpload => Boolean(item));
}
export function frameUploadByRole(uploads: PlaygroundUpload[], role: PlaygroundUploadRole) {
return uploads.find((item) => item.role === role);
}
function uploadListsEqual(left: PlaygroundUpload[], right: PlaygroundUpload[]) {
if (left.length !== right.length) return false;
return left.every((item, index) => {
const next = right[index];
return next && item.id === next.id && item.role === next.role && item.kind === next.kind;
});
}
function recordFromUnknown(value: unknown): Record<string, unknown> | undefined {
if (!value || typeof value !== 'object' || Array.isArray(value)) return undefined;
return value as Record<string, unknown>;
}
function firstString(...values: unknown[]) {
for (const value of values) {
const text = typeof value === 'string' ? value.trim() : '';
if (text) return text;
}
return '';
}
function newLocalId() {
return typeof crypto !== 'undefined' && 'randomUUID' in crypto
? crypto.randomUUID()
: `${Date.now()}-${Math.random().toString(36).slice(2)}`;
}

View File

@ -39,6 +39,7 @@ const adminPaths: Record<AdminSection, string> = {
auditLogs: '/admin/audit-logs',
runtime: '/admin/runtime',
accessRules: '/admin/access-rules',
systemSettings: '/admin/system-settings',
};
const docsPaths: Record<ApiDocSection, string> = {

View File

@ -1812,6 +1812,121 @@
justify-content: flex-end;
}
.fileStoragePanel {
display: grid;
gap: 14px;
}
.fileStorageToolbar {
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
padding: 14px 16px;
border: 1px solid var(--border);
border-radius: 10px;
background: #fff;
}
.fileStorageSettingsCard {
display: grid;
grid-template-columns: minmax(0, 1fr) minmax(260px, 380px) auto;
align-items: center;
gap: 14px;
padding: 14px 16px;
border: 1px solid var(--border);
border-radius: 10px;
background: #fff;
}
.fileStorageSettingsCard > div {
display: grid;
gap: 3px;
}
.fileStorageSettingsCard > label {
min-width: 0;
}
.fileStorageToolbar > div {
display: grid;
gap: 3px;
}
.fileStorageToolbar strong {
color: var(--text-normal);
}
.fileStorageToolbar span,
.fileStorageSettingsCard span,
.fileStorageMeta span {
color: var(--muted-foreground);
font-size: var(--font-size-xs);
line-height: 1.45;
}
.fileStorageGrid {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 14px;
}
.fileStorageCard {
display: grid;
gap: 12px;
padding: 16px;
border: 1px solid var(--border);
border-radius: 10px;
background: #fff;
}
.fileStorageCard header,
.fileStorageCard footer {
display: flex;
align-items: center;
gap: 10px;
}
.fileStorageCard header > div:nth-child(2) {
display: grid;
min-width: 0;
flex: 1;
gap: 3px;
}
.fileStorageCard strong,
.fileStorageCard header span,
.fileStorageMeta span {
overflow: hidden;
text-overflow: ellipsis;
}
.fileStorageMeta {
display: grid;
gap: 7px;
}
.fileStorageMeta span {
padding: 7px 9px;
border: 1px solid var(--border-subtle);
border-radius: 8px;
background: var(--surface-subtle);
}
.fileStorageCard footer {
justify-content: flex-end;
}
.fileStorageDialogBody {
grid-template-columns: repeat(2, minmax(0, 1fr));
}
.fileStorageSceneGrid {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 10px;
}
.runtimePolicyDialog {
width: min(980px, 100%);
}
@ -1983,6 +2098,7 @@
.providerCatalogGrid,
.baseModelGrid,
.runtimePolicyGrid,
.fileStorageGrid,
.platformGrid,
.accessPermissionGrid,
.platformModelChoices {
@ -2023,7 +2139,11 @@
.platformModelRow,
.platformModelToolbar,
.runtimePolicyGrid,
.fileStorageGrid,
.fileStorageSettingsCard,
.fileStorageSceneGrid,
.runtimePolicyFormBody,
.fileStorageDialogBody,
.runtimePolicyRows,
.runnerActionGrid,
.accessPermissionGrid,

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,8 @@ export type AdminSection =
| 'userGroups'
| 'auditLogs'
| 'runtime'
| 'accessRules';
| 'accessRules'
| 'systemSettings';
export interface LoginForm {
account: string;

View File

@ -506,6 +506,32 @@ export interface PlayableGatewayApiKey extends GatewayApiKey {
secret: string;
}
export interface GatewayPricingEstimateItem {
amount?: number;
currency?: string;
discountFactor?: number;
durationSeconds?: number;
durationUnit?: string;
durationUnitCount?: number;
model?: string;
modelAlias?: string;
platformId?: string;
platformModelId?: string;
provider?: string;
quantity?: number | string;
resourceType?: string;
simulated?: boolean;
unit?: string;
[key: string]: unknown;
}
export interface GatewayPricingEstimate {
items: GatewayPricingEstimateItem[];
resolver: string;
totalAmount?: number;
currency?: string;
}
export interface GatewayWalletAccount {
id: string;
gatewayTenantId?: string;
@ -844,6 +870,46 @@ export interface GatewayNetworkProxyConfig {
globalHttpProxySource?: string;
}
export interface FileStorageChannel {
id: string;
channelKey: string;
name: string;
provider: 'server_main_openapi' | 'aliyun_oss' | 'tencent_cos' | string;
uploadUrl?: string;
credentialsPreview?: Record<string, unknown>;
scenes?: string[];
config?: Record<string, unknown>;
retryPolicy?: Record<string, unknown>;
priority: number;
status: 'enabled' | 'disabled' | string;
lastError?: string;
lastFailedAt?: string;
lastSucceededAt?: string;
createdAt: string;
updatedAt: string;
}
export interface FileStorageChannelUpsertRequest {
channelKey: string;
name: string;
provider?: 'server_main_openapi' | 'aliyun_oss' | 'tencent_cos' | string;
uploadUrl?: string;
apiKey?: string;
scenes?: string[];
config?: Record<string, unknown>;
retryPolicy?: Record<string, unknown>;
priority?: number;
status?: 'enabled' | 'disabled' | string;
}
export interface FileStorageSettings {
resultUploadPolicy: 'default' | 'upload_all' | 'upload_none' | string;
}
export interface FileStorageSettingsUpdateRequest {
resultUploadPolicy: 'default' | 'upload_all' | 'upload_none' | string;
}
export interface GatewayTask {
id: string;
kind: string;