229 lines
8.5 KiB
Go
229 lines
8.5 KiB
Go
package httpapi
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"io"
|
|
"log/slog"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
|
)
|
|
|
|
// TestRequestAssetFromValueDetectsDataURLAndRawBase64 checks the two inline
// asset encodings requestAssetFromValue accepts:
//   - a "data:<mime>;base64,<payload>" URL, whose MIME type comes from the prefix;
//   - a bare base64 payload, whose content type is derived from a "format" hint
//     ("mp3" is expected to map to "audio/mpeg").
func TestRequestAssetFromValueDetectsDataURLAndRawBase64(t *testing.T) {
	imagePath := []string{"messages", "[0]", "content", "[1]", "image_url"}
	imageURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte("inline image"))
	imageAsset, imageOK, imageErr := requestAssetFromValue("url", imagePath, imageURL, nil)
	switch {
	case imageErr != nil:
		t.Fatalf("decode data URL: %v", imageErr)
	case !imageOK, imageAsset.ContentType != "image/png", string(imageAsset.Bytes) != "inline image":
		t.Fatalf("unexpected data URL asset: ok=%v decoded=%+v", imageOK, imageAsset)
	}

	// Raw base64 carries no MIME prefix, so the sibling "format" field is the
	// only hint about its content type.
	rawAudio := base64.StdEncoding.EncodeToString([]byte("inline audio"))
	audioAsset, audioOK, audioErr := requestAssetFromValue("data", []string{"input_audio"}, rawAudio, map[string]any{"format": "mp3"})
	switch {
	case audioErr != nil:
		t.Fatalf("decode raw audio: %v", audioErr)
	case !audioOK, audioAsset.ContentType != "audio/mpeg", string(audioAsset.Bytes) != "inline audio":
		t.Fatalf("unexpected raw audio asset: ok=%v decoded=%+v", audioOK, audioAsset)
	}
}
|
|
|
|
// TestCanonicalConversationMessageHashUsesTextAndAssetRefs checks that the
// canonical message hash is derived from the message text plus the sha256 of
// any referenced assets:
//   - changing only the resource URL (same asset sha) must not change the hash;
//   - changing only the text (same asset) must change the hash;
//   - changing only the asset sha (same text) must change the hash;
//   - the returned asset hashes list the referenced sha values.
func TestCanonicalConversationMessageHashUsesTextAndAssetRefs(t *testing.T) {
	// userImageMessage builds a user message with one text part and one
	// image_url part that carries an assetRef for the given URL and sha.
	userImageMessage := func(text, url, sha string) map[string]any {
		return map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "text", "text": text},
				map[string]any{"type": "image_url", "image_url": map[string]any{
					"url":      url,
					"assetRef": map[string]any{"sha256": sha, "url": url},
				}},
			},
		}
	}

	message := userImageMessage("describe it", "https://cdn.example/a.png", "sha-a")
	sameMessage := userImageMessage("describe it", "https://different.example/a.png", "sha-a")
	changedMessage := map[string]any{
		"role":    "user",
		"content": "describe something else",
	}
	// The two cases below vary exactly one input each, so a hash that ignored
	// text (or ignored asset refs) can no longer pass by accident.
	textChangedMessage := userImageMessage("describe something else", "https://cdn.example/a.png", "sha-a")
	assetChangedMessage := userImageMessage("describe it", "https://cdn.example/a.png", "sha-b")

	firstHash, assetHashes := canonicalConversationMessageHash(message)
	secondHash, _ := canonicalConversationMessageHash(sameMessage)
	changedHash, _ := canonicalConversationMessageHash(changedMessage)
	textChangedHash, _ := canonicalConversationMessageHash(textChangedMessage)
	assetChangedHash, _ := canonicalConversationMessageHash(assetChangedMessage)

	if firstHash != secondHash {
		t.Fatalf("message hash should ignore resource URL drift when asset sha is stable")
	}
	if firstHash == changedHash {
		t.Fatalf("message hash should change when text changes")
	}
	if firstHash == textChangedHash {
		t.Fatalf("message hash should change when only the text changes and the asset is unchanged")
	}
	if firstHash == assetChangedHash {
		t.Fatalf("message hash should change when only the asset sha changes")
	}
	if len(assetHashes) != 1 || assetHashes[0] != "sha-a" {
		t.Fatalf("unexpected asset hashes: %+v", assetHashes)
	}
}
|
|
|
|
// TestImageEditMultipartFormBodyMapsFilesAndFields posts a multipart image-edit
// form and checks how imageEditMultipartFormBody turns it into a JSON-like body:
//   - scalar text fields are copied through, "n" is parsed as a number, and
//     the JSON-valued options field is decoded into an object;
//   - the "image" and "mask" files become single URL wrappers;
//   - the "images" and "images[]" files are merged into one "images" array.
//
// The upload callback is a stub that fakes an asset ref from the field name and
// filename, so no real storage is touched.
func TestImageEditMultipartFormBodyMapsFilesAndFields(t *testing.T) {
	var payload bytes.Buffer
	form := multipart.NewWriter(&payload)

	textFields := []struct {
		key, value, failure string
	}{
		{"model", "doubao-5.0图像编辑", "write model field: %v"},
		{"prompt", "换个姿势", "write prompt field: %v"},
		{"n", "2", "write n field: %v"},
		{"sequential_image_generation_options", `{"max_images":2}`, "write sequential options field: %v"},
	}
	for _, f := range textFields {
		if err := form.WriteField(f.key, f.value); err != nil {
			t.Fatalf(f.failure, err)
		}
	}
	for _, file := range [][2]string{
		{"image", "single.png"},
		{"images", "ref-a.png"},
		{"images[]", "ref-b.png"},
		{"mask", "mask.png"},
	} {
		writeMultipartFixtureFile(t, form, file[0], file[1])
	}
	if err := form.Close(); err != nil {
		t.Fatalf("close multipart writer: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/v1/images/edits", &payload)
	req.Header.Set("Content-Type", form.FormDataContentType())
	if err := req.ParseMultipartForm(multipartTaskMemoryBytes); err != nil {
		t.Fatalf("parse multipart form: %v", err)
	}
	defer req.MultipartForm.RemoveAll()

	uploadStub := func(_ context.Context, field string, fh *multipart.FileHeader) (map[string]any, error) {
		return requestAssetWrapper(map[string]any{
			"sha256":          field + "-" + fh.Filename,
			"url":             "https://cdn.example/" + fh.Filename,
			"contentType":     fh.Header.Get("Content-Type"),
			"storageProvider": "server_main_openapi",
		}), nil
	}
	body, err := imageEditMultipartFormBody(context.Background(), req.MultipartForm, uploadStub)
	if err != nil {
		t.Fatalf("build multipart image edit body: %v", err)
	}

	if body["model"] != "doubao-5.0图像编辑" || body["prompt"] != "换个姿势" {
		t.Fatalf("unexpected scalar fields: %+v", body)
	}
	if count := body["n"]; count != float64(2) {
		t.Fatalf("n should be parsed as number, got %#v", count)
	}
	if opts, _ := body["sequential_image_generation_options"].(map[string]any); opts["max_images"] != float64(2) {
		t.Fatalf("sequential options should parse JSON object, got %+v", opts)
	}
	if single, _ := body["image"].(map[string]any); single["url"] != "https://cdn.example/single.png" {
		t.Fatalf("single image should map to image URL wrapper, got %+v", single)
	}

	refs, _ := body["images"].([]any)
	if len(refs) != 2 {
		t.Fatalf("multi image fields should map to images array, got %+v", body["images"])
	}
	refA, _ := refs[0].(map[string]any)
	refB, _ := refs[1].(map[string]any)
	if refA["url"] != "https://cdn.example/ref-a.png" || refB["url"] != "https://cdn.example/ref-b.png" {
		t.Fatalf("unexpected images array: %+v", refs)
	}

	if maskRef, _ := body["mask"].(map[string]any); maskRef["url"] != "https://cdn.example/mask.png" {
		t.Fatalf("mask should map to mask URL wrapper, got %+v", maskRef)
	}
}
|
|
|
|
// writeMultipartFixtureFile adds a file part to writer under form field `field`
// with the given filename. The part body is just the 8-byte PNG signature,
// which is enough for content sniffing. Any write error fails the test.
func writeMultipartFixtureFile(t *testing.T, writer *multipart.Writer, field string, filename string) {
	t.Helper()
	pngSignature := []byte("\x89PNG\r\n\x1a\n")

	part, createErr := writer.CreateFormFile(field, filename)
	if createErr != nil {
		t.Fatalf("create multipart file %s/%s: %v", field, filename, createErr)
	}
	if _, writeErr := part.Write(pngSignature); writeErr != nil {
		t.Fatalf("write multipart file %s/%s: %v", field, filename, writeErr)
	}
}
|
|
|
|
// TestCleanupExpiredLocalTempAssetsDeletesExpiredStaticFiles seeds the uploaded
// and generated local storage directories with one file older than the
// configured 24h TTL and one newer than it. It then checks that
// cleanupExpiredLocalTempAssets reports and removes exactly the two old files
// and leaves the fresh ones in place.
func TestCleanupExpiredLocalTempAssetsDeletesExpiredStaticFiles(t *testing.T) {
	uploadedDir, generatedDir := t.TempDir(), t.TempDir()

	oldUploaded := filepath.Join(uploadedDir, requestAssetFilePrefix+"old.png")
	freshUploaded := filepath.Join(uploadedDir, requestAssetFilePrefix+"fresh.png")
	oldGenerated := filepath.Join(generatedDir, "gateway-result-old.png")
	freshGenerated := filepath.Join(generatedDir, "gateway-result-fresh.png")
	expired := []string{oldUploaded, oldGenerated}
	retained := []string{freshUploaded, freshGenerated}

	for _, group := range [][]string{expired, retained} {
		for _, path := range group {
			if err := os.WriteFile(path, []byte("asset"), 0o644); err != nil {
				t.Fatalf("write fixture %s: %v", path, err)
			}
		}
	}

	now := time.Now()
	// Set modification times an hour on either side of the 24h TTL so the
	// outcome does not depend on wall-clock jitter during the test.
	backdate := func(paths []string, age time.Duration, failure string) {
		stamp := now.Add(-age)
		for _, path := range paths {
			if err := os.Chtimes(path, stamp, stamp); err != nil {
				t.Fatalf(failure, path, err)
			}
		}
	}
	backdate(expired, 25*time.Hour, "touch old static asset %s: %v")
	backdate(retained, 23*time.Hour, "touch fresh static asset %s: %v")

	srv := &Server{
		cfg: config.Config{
			LocalGeneratedStorageDir: generatedDir,
			LocalUploadedStorageDir:  uploadedDir,
			LocalTempAssetTTLHours:   24,
		},
		logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
	}

	if removed := srv.cleanupExpiredLocalTempAssets(context.Background(), now); removed != 2 {
		t.Fatalf("expected two expired static asset deletes, got %d", removed)
	}
	for _, path := range expired {
		if _, statErr := os.Stat(path); !os.IsNotExist(statErr) {
			t.Fatalf("old static asset should be deleted %s, stat err=%v", path, statErr)
		}
	}
	for _, path := range retained {
		if _, statErr := os.Stat(path); statErr != nil {
			t.Fatalf("fresh static asset should remain %s: %v", path, statErr)
		}
	}
}
|
|
|
|
func TestRequestConversationKeyPriority(t *testing.T) {
|
|
request := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
|
request.Header.Set("X-EasyAI-Conversation-ID", "from-header")
|
|
body := map[string]any{
|
|
"conversation_id": "from-body",
|
|
"metadata": map[string]any{"conversation_id": "from-metadata"},
|
|
}
|
|
if got := requestConversationKey(request, body); got != "from-header" {
|
|
t.Fatalf("expected header conversation id, got %q", got)
|
|
}
|
|
request.Header.Del("X-EasyAI-Conversation-ID")
|
|
if got := requestConversationKey(request, body); got != "from-body" {
|
|
t.Fatalf("expected body conversation id, got %q", got)
|
|
}
|
|
delete(body, "conversation_id")
|
|
if got := requestConversationKey(request, body); got != "from-metadata" {
|
|
t.Fatalf("expected metadata conversation id, got %q", got)
|
|
}
|
|
}
|