Compare commits
16 Commits
feature/op
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 8ad5b06c18 | |||
| af9b281d34 | |||
| 11a2c13e4a | |||
| 69d23efb57 | |||
| baffccf8f8 | |||
| 13186f8ed1 | |||
| 73c6d43e4b | |||
| ba419cd90a | |||
| d09a4c2e4d | |||
| 90c3315468 | |||
| ae197a742f | |||
| 34c3251c6d | |||
| 62d426bdfb | |||
| 7abb6a1baf | |||
| be283daaa3 | |||
| 2a9a833cd7 |
10
.gitignore
vendored
10
.gitignore
vendored
@ -15,3 +15,13 @@ coverage/
|
|||||||
.idea
|
.idea
|
||||||
.gitignore
|
.gitignore
|
||||||
|
|
||||||
|
# Devenv
|
||||||
|
.devenv*
|
||||||
|
devenv.local.nix
|
||||||
|
devenv.local.yaml
|
||||||
|
|
||||||
|
# direnv
|
||||||
|
.direnv
|
||||||
|
|
||||||
|
# pre-commit
|
||||||
|
.pre-commit-config.yaml
|
||||||
|
|||||||
@ -2357,6 +2357,355 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/api/admin/system/file-storage/channels": {
|
||||||
|
"get": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "返回所有未删除的文件存储通道,用于管理上传与生成资源回传策略。",
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"system"
|
||||||
|
],
|
||||||
|
"summary": "列出文件存储通道",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.FileStorageChannelListResponse"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Forbidden",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Internal Server Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"post": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "创建文件存储通道,当前主要用于配置 server-main OpenAPI 上传通道。",
|
||||||
|
"consumes": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"system"
|
||||||
|
],
|
||||||
|
"summary": "创建文件存储通道",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"description": "文件存储通道",
|
||||||
|
"name": "body",
|
||||||
|
"in": "body",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageChannelInput"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"201": {
|
||||||
|
"description": "Created",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageChannel"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"description": "Bad Request",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Forbidden",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"409": {
|
||||||
|
"description": "Conflict",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Internal Server Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/api/admin/system/file-storage/channels/{channelID}": {
|
||||||
|
"delete": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "软删除指定文件存储通道。",
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"system"
|
||||||
|
],
|
||||||
|
"summary": "删除文件存储通道",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件存储通道 ID",
|
||||||
|
"name": "channelID",
|
||||||
|
"in": "path",
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"204": {
|
||||||
|
"description": "No Content"
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Forbidden",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "Not Found",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Internal Server Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"patch": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "更新指定文件存储通道的名称、凭证、场景、优先级、状态和重试策略。",
|
||||||
|
"consumes": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"system"
|
||||||
|
],
|
||||||
|
"summary": "更新文件存储通道",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "文件存储通道 ID",
|
||||||
|
"name": "channelID",
|
||||||
|
"in": "path",
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "文件存储通道",
|
||||||
|
"name": "body",
|
||||||
|
"in": "body",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageChannelInput"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageChannel"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"description": "Bad Request",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Forbidden",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "Not Found",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"409": {
|
||||||
|
"description": "Conflict",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Internal Server Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/api/admin/system/file-storage/settings": {
|
||||||
|
"get": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "返回文件存储系统设置;数据库对象尚未创建时返回默认设置。",
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"system"
|
||||||
|
],
|
||||||
|
"summary": "获取文件存储设置",
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageSettings"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Forbidden",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Internal Server Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"patch": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "更新生成资源上传策略等文件存储系统设置。",
|
||||||
|
"consumes": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"system"
|
||||||
|
],
|
||||||
|
"summary": "更新文件存储设置",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"description": "文件存储设置",
|
||||||
|
"name": "body",
|
||||||
|
"in": "body",
|
||||||
|
"required": true,
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageSettingsInput"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageSettings"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"description": "Bad Request",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"403": {
|
||||||
|
"description": "Forbidden",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Internal Server Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/api/admin/tenants": {
|
"/api/admin/tenants": {
|
||||||
"get": {
|
"get": {
|
||||||
"security": [
|
"security": [
|
||||||
@ -3651,26 +4000,27 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "/api/v1/chat/completions 同步执行:stream=true 返回 text/event-stream SSE;stream=false 或未传返回兼容 JSON;该接口忽略 X-Async。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
"produces": [
|
"produces": [
|
||||||
"application/json"
|
"application/json",
|
||||||
|
"text/event-stream"
|
||||||
],
|
],
|
||||||
"tags": [
|
"tags": [
|
||||||
"tasks"
|
"tasks"
|
||||||
],
|
],
|
||||||
"summary": "创建或执行 AI 任务",
|
"summary": "创建 Chat Completions",
|
||||||
"parameters": [
|
"parameters": [
|
||||||
{
|
{
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "true 时异步创建任务并返回 202",
|
"description": "该接口忽略此参数",
|
||||||
"name": "X-Async",
|
"name": "X-Async",
|
||||||
"in": "header"
|
"in": "header"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"description": "AI 任务请求,字段随任务类型变化",
|
"description": "Chat Completions 请求",
|
||||||
"name": "input",
|
"name": "input",
|
||||||
"in": "body",
|
"in": "body",
|
||||||
"required": true,
|
"required": true,
|
||||||
@ -3683,13 +4033,7 @@
|
|||||||
"200": {
|
"200": {
|
||||||
"description": "OK",
|
"description": "OK",
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/definitions/httpapi.CompatibleResponse"
|
"$ref": "#/definitions/httpapi.ChatCompletionCompatibleResponse"
|
||||||
}
|
|
||||||
},
|
|
||||||
"202": {
|
|
||||||
"description": "Accepted",
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/definitions/httpapi.TaskAcceptedResponse"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"400": {
|
"400": {
|
||||||
@ -3737,6 +4081,74 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/api/v1/files/upload": {
|
||||||
|
"post": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "上传文件到配置的文件存储通道;没有启用通道时回退到本地静态上传目录。单文件最大 256MiB。",
|
||||||
|
"consumes": [
|
||||||
|
"multipart/form-data"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"files"
|
||||||
|
],
|
||||||
|
"summary": "上传文件",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"description": "要上传的文件",
|
||||||
|
"name": "file",
|
||||||
|
"in": "formData",
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"default": "ai-gateway-openapi",
|
||||||
|
"description": "上传来源标识",
|
||||||
|
"name": "source",
|
||||||
|
"in": "formData"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.FileUploadResponse"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"description": "Bad Request",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"502": {
|
||||||
|
"description": "Bad Gateway",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"503": {
|
||||||
|
"description": "Service Unavailable",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/api/v1/images/edits": {
|
"/api/v1/images/edits": {
|
||||||
"post": {
|
"post": {
|
||||||
"security": [
|
"security": [
|
||||||
@ -3744,7 +4156,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -3837,7 +4249,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -4236,7 +4648,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -4568,7 +4980,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5035,7 +5447,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5148,7 +5560,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5241,7 +5653,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5360,7 +5772,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5446,6 +5858,41 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/static/generated/{asset}": {
|
||||||
|
"get": {
|
||||||
|
"description": "从本地生成资源目录读取图片、视频等任务产物;不存在时返回 404。",
|
||||||
|
"produces": [
|
||||||
|
"application/octet-stream"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"static"
|
||||||
|
],
|
||||||
|
"summary": "获取本地生成资源",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "资源文件名",
|
||||||
|
"name": "asset",
|
||||||
|
"in": "path",
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"type": "file"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "Not Found",
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/static/simulation/{asset}": {
|
"/static/simulation/{asset}": {
|
||||||
"get": {
|
"get": {
|
||||||
"description": "返回本地模拟模式使用的图片、视频封面或短视频资源。",
|
"description": "返回本地模拟模式使用的图片、视频封面或短视频资源。",
|
||||||
@ -5482,6 +5929,41 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/static/uploaded/{asset}": {
|
||||||
|
"get": {
|
||||||
|
"description": "从本地上传资源目录读取用户上传文件;不存在时返回 404。",
|
||||||
|
"produces": [
|
||||||
|
"application/octet-stream"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"static"
|
||||||
|
],
|
||||||
|
"summary": "获取本地上传资源",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"description": "资源文件名",
|
||||||
|
"name": "asset",
|
||||||
|
"in": "path",
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"type": "file"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "Not Found",
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/v1/chat/completions": {
|
"/v1/chat/completions": {
|
||||||
"post": {
|
"post": {
|
||||||
"security": [
|
"security": [
|
||||||
@ -5489,7 +5971,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5575,6 +6057,74 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/v1/files/upload": {
|
||||||
|
"post": {
|
||||||
|
"security": [
|
||||||
|
{
|
||||||
|
"BearerAuth": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "上传文件到配置的文件存储通道;没有启用通道时回退到本地静态上传目录。单文件最大 256MiB。",
|
||||||
|
"consumes": [
|
||||||
|
"multipart/form-data"
|
||||||
|
],
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"files"
|
||||||
|
],
|
||||||
|
"summary": "上传文件",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"description": "要上传的文件",
|
||||||
|
"name": "file",
|
||||||
|
"in": "formData",
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"default": "ai-gateway-openapi",
|
||||||
|
"description": "上传来源标识",
|
||||||
|
"name": "source",
|
||||||
|
"in": "formData"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "OK",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.FileUploadResponse"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"description": "Bad Request",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"401": {
|
||||||
|
"description": "Unauthorized",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"502": {
|
||||||
|
"description": "Bad Gateway",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"503": {
|
||||||
|
"description": "Service Unavailable",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/httpapi.ErrorEnvelope"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/v1/images/edits": {
|
"/v1/images/edits": {
|
||||||
"post": {
|
"post": {
|
||||||
"security": [
|
"security": [
|
||||||
@ -5582,7 +6132,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5675,7 +6225,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5768,7 +6318,7 @@
|
|||||||
"BearerAuth": []
|
"BearerAuth": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
@ -5996,6 +6546,82 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"httpapi.ChatCompletionChoice": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"finish_reason": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "stop"
|
||||||
|
},
|
||||||
|
"index": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": 0
|
||||||
|
},
|
||||||
|
"message": {
|
||||||
|
"$ref": "#/definitions/httpapi.ChatCompletionChoiceMessage"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"httpapi.ChatCompletionChoiceMessage": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "Hello"
|
||||||
|
},
|
||||||
|
"role": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"httpapi.ChatCompletionCompatibleResponse": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"choices": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/definitions/httpapi.ChatCompletionChoice"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"created": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": 1710000000
|
||||||
|
},
|
||||||
|
"id": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "chatcmpl-123"
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "gpt-4o-mini"
|
||||||
|
},
|
||||||
|
"object": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "chat.completion"
|
||||||
|
},
|
||||||
|
"usage": {
|
||||||
|
"$ref": "#/definitions/httpapi.ChatCompletionUsage"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"httpapi.ChatCompletionUsage": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"completion_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": 8
|
||||||
|
},
|
||||||
|
"prompt_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": 12
|
||||||
|
},
|
||||||
|
"total_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"httpapi.ChatMessage": {
|
"httpapi.ChatMessage": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@ -6062,6 +6688,46 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"httpapi.FileStorageChannelListResponse": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"items": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/definitions/store.FileStorageChannel"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"httpapi.FileUploadResponse": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"assetStorage": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": true
|
||||||
|
},
|
||||||
|
"contentType": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "image/png"
|
||||||
|
},
|
||||||
|
"filename": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "image.png"
|
||||||
|
},
|
||||||
|
"id": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "file_abc123"
|
||||||
|
},
|
||||||
|
"size": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": 1024
|
||||||
|
},
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "/static/uploaded/upload-abc123.png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"httpapi.HealthResponse": {
|
"httpapi.HealthResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@ -6615,6 +7281,11 @@
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "A watercolor robot reading a book"
|
"example": "A watercolor robot reading a book"
|
||||||
},
|
},
|
||||||
|
"reasoning_effort": {
|
||||||
|
"description": "ReasoningEffort 推理深度,OpenAI-compatible 请求字段;开放字符串,取值随 provider 和模型能力而定,常见值为 none、minimal、low、medium、high、xhigh,也可配置 max 等供应商自定义值。",
|
||||||
|
"type": "string",
|
||||||
|
"example": "medium"
|
||||||
|
},
|
||||||
"resolution": {
|
"resolution": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"example": "720p"
|
"example": "720p"
|
||||||
@ -7407,6 +8078,121 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"store.FileStorageChannel": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"channelKey": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {}
|
||||||
|
},
|
||||||
|
"createdAt": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"credentialsPreview": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {}
|
||||||
|
},
|
||||||
|
"id": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"lastError": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"lastFailedAt": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"lastSucceededAt": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"retryPolicy": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {}
|
||||||
|
},
|
||||||
|
"scenes": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"status": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"updatedAt": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"uploadUrl": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"store.FileStorageChannelInput": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"apiKey": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"channelKey": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {}
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"retryPolicy": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {}
|
||||||
|
},
|
||||||
|
"scenes": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"status": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"uploadUrl": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"store.FileStorageSettings": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"resultUploadPolicy": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"store.FileStorageSettingsInput": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"resultUploadPolicy": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"store.GatewayTask": {
|
"store.GatewayTask": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|||||||
@ -92,6 +92,59 @@ definitions:
|
|||||||
$ref: '#/definitions/store.CatalogProvider'
|
$ref: '#/definitions/store.CatalogProvider'
|
||||||
type: array
|
type: array
|
||||||
type: object
|
type: object
|
||||||
|
httpapi.ChatCompletionChoice:
|
||||||
|
properties:
|
||||||
|
finish_reason:
|
||||||
|
example: stop
|
||||||
|
type: string
|
||||||
|
index:
|
||||||
|
example: 0
|
||||||
|
type: integer
|
||||||
|
message:
|
||||||
|
$ref: '#/definitions/httpapi.ChatCompletionChoiceMessage'
|
||||||
|
type: object
|
||||||
|
httpapi.ChatCompletionChoiceMessage:
|
||||||
|
properties:
|
||||||
|
content:
|
||||||
|
example: Hello
|
||||||
|
type: string
|
||||||
|
role:
|
||||||
|
example: assistant
|
||||||
|
type: string
|
||||||
|
type: object
|
||||||
|
httpapi.ChatCompletionCompatibleResponse:
|
||||||
|
properties:
|
||||||
|
choices:
|
||||||
|
items:
|
||||||
|
$ref: '#/definitions/httpapi.ChatCompletionChoice'
|
||||||
|
type: array
|
||||||
|
created:
|
||||||
|
example: 1710000000
|
||||||
|
type: integer
|
||||||
|
id:
|
||||||
|
example: chatcmpl-123
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
example: gpt-4o-mini
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
example: chat.completion
|
||||||
|
type: string
|
||||||
|
usage:
|
||||||
|
$ref: '#/definitions/httpapi.ChatCompletionUsage'
|
||||||
|
type: object
|
||||||
|
httpapi.ChatCompletionUsage:
|
||||||
|
properties:
|
||||||
|
completion_tokens:
|
||||||
|
example: 8
|
||||||
|
type: integer
|
||||||
|
prompt_tokens:
|
||||||
|
example: 12
|
||||||
|
type: integer
|
||||||
|
total_tokens:
|
||||||
|
example: 20
|
||||||
|
type: integer
|
||||||
|
type: object
|
||||||
httpapi.ChatMessage:
|
httpapi.ChatMessage:
|
||||||
properties:
|
properties:
|
||||||
content:
|
content:
|
||||||
@ -138,6 +191,34 @@ definitions:
|
|||||||
example: 400
|
example: 400
|
||||||
type: integer
|
type: integer
|
||||||
type: object
|
type: object
|
||||||
|
httpapi.FileStorageChannelListResponse:
|
||||||
|
properties:
|
||||||
|
items:
|
||||||
|
items:
|
||||||
|
$ref: '#/definitions/store.FileStorageChannel'
|
||||||
|
type: array
|
||||||
|
type: object
|
||||||
|
httpapi.FileUploadResponse:
|
||||||
|
properties:
|
||||||
|
assetStorage:
|
||||||
|
additionalProperties: true
|
||||||
|
type: object
|
||||||
|
contentType:
|
||||||
|
example: image/png
|
||||||
|
type: string
|
||||||
|
filename:
|
||||||
|
example: image.png
|
||||||
|
type: string
|
||||||
|
id:
|
||||||
|
example: file_abc123
|
||||||
|
type: string
|
||||||
|
size:
|
||||||
|
example: 1024
|
||||||
|
type: integer
|
||||||
|
url:
|
||||||
|
example: /static/uploaded/upload-abc123.png
|
||||||
|
type: string
|
||||||
|
type: object
|
||||||
httpapi.HealthResponse:
|
httpapi.HealthResponse:
|
||||||
properties:
|
properties:
|
||||||
env:
|
env:
|
||||||
@ -506,6 +587,11 @@ definitions:
|
|||||||
prompt:
|
prompt:
|
||||||
example: A watercolor robot reading a book
|
example: A watercolor robot reading a book
|
||||||
type: string
|
type: string
|
||||||
|
reasoning_effort:
|
||||||
|
description: ReasoningEffort 推理深度,OpenAI-compatible 请求字段;开放字符串,取值随 provider
|
||||||
|
和模型能力而定,常见值为 none、minimal、low、medium、high、xhigh,也可配置 max 等供应商自定义值。
|
||||||
|
example: medium
|
||||||
|
type: string
|
||||||
resolution:
|
resolution:
|
||||||
example: 720p
|
example: 720p
|
||||||
type: string
|
type: string
|
||||||
@ -1045,6 +1131,83 @@ definitions:
|
|||||||
secret:
|
secret:
|
||||||
type: string
|
type: string
|
||||||
type: object
|
type: object
|
||||||
|
store.FileStorageChannel:
|
||||||
|
properties:
|
||||||
|
channelKey:
|
||||||
|
type: string
|
||||||
|
config:
|
||||||
|
additionalProperties: {}
|
||||||
|
type: object
|
||||||
|
createdAt:
|
||||||
|
type: string
|
||||||
|
credentialsPreview:
|
||||||
|
additionalProperties: {}
|
||||||
|
type: object
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
lastError:
|
||||||
|
type: string
|
||||||
|
lastFailedAt:
|
||||||
|
type: string
|
||||||
|
lastSucceededAt:
|
||||||
|
type: string
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
priority:
|
||||||
|
type: integer
|
||||||
|
provider:
|
||||||
|
type: string
|
||||||
|
retryPolicy:
|
||||||
|
additionalProperties: {}
|
||||||
|
type: object
|
||||||
|
scenes:
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type: array
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
updatedAt:
|
||||||
|
type: string
|
||||||
|
uploadUrl:
|
||||||
|
type: string
|
||||||
|
type: object
|
||||||
|
store.FileStorageChannelInput:
|
||||||
|
properties:
|
||||||
|
apiKey:
|
||||||
|
type: string
|
||||||
|
channelKey:
|
||||||
|
type: string
|
||||||
|
config:
|
||||||
|
additionalProperties: {}
|
||||||
|
type: object
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
priority:
|
||||||
|
type: integer
|
||||||
|
provider:
|
||||||
|
type: string
|
||||||
|
retryPolicy:
|
||||||
|
additionalProperties: {}
|
||||||
|
type: object
|
||||||
|
scenes:
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type: array
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
uploadUrl:
|
||||||
|
type: string
|
||||||
|
type: object
|
||||||
|
store.FileStorageSettings:
|
||||||
|
properties:
|
||||||
|
resultUploadPolicy:
|
||||||
|
type: string
|
||||||
|
type: object
|
||||||
|
store.FileStorageSettingsInput:
|
||||||
|
properties:
|
||||||
|
resultUploadPolicy:
|
||||||
|
type: string
|
||||||
|
type: object
|
||||||
store.GatewayTask:
|
store.GatewayTask:
|
||||||
properties:
|
properties:
|
||||||
apiKeyId:
|
apiKeyId:
|
||||||
@ -3644,6 +3807,229 @@ paths:
|
|||||||
summary: 更新 Runner 策略
|
summary: 更新 Runner 策略
|
||||||
tags:
|
tags:
|
||||||
- runtime
|
- runtime
|
||||||
|
/api/admin/system/file-storage/channels:
|
||||||
|
get:
|
||||||
|
description: 返回所有未删除的文件存储通道,用于管理上传与生成资源回传策略。
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.FileStorageChannelListResponse'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"403":
|
||||||
|
description: Forbidden
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 列出文件存储通道
|
||||||
|
tags:
|
||||||
|
- system
|
||||||
|
post:
|
||||||
|
consumes:
|
||||||
|
- application/json
|
||||||
|
description: 创建文件存储通道,当前主要用于配置 server-main OpenAPI 上传通道。
|
||||||
|
parameters:
|
||||||
|
- description: 文件存储通道
|
||||||
|
in: body
|
||||||
|
name: body
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageChannelInput'
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"201":
|
||||||
|
description: Created
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageChannel'
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"403":
|
||||||
|
description: Forbidden
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"409":
|
||||||
|
description: Conflict
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 创建文件存储通道
|
||||||
|
tags:
|
||||||
|
- system
|
||||||
|
/api/admin/system/file-storage/channels/{channelID}:
|
||||||
|
delete:
|
||||||
|
description: 软删除指定文件存储通道。
|
||||||
|
parameters:
|
||||||
|
- description: 文件存储通道 ID
|
||||||
|
in: path
|
||||||
|
name: channelID
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"204":
|
||||||
|
description: No Content
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"403":
|
||||||
|
description: Forbidden
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"404":
|
||||||
|
description: Not Found
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 删除文件存储通道
|
||||||
|
tags:
|
||||||
|
- system
|
||||||
|
patch:
|
||||||
|
consumes:
|
||||||
|
- application/json
|
||||||
|
description: 更新指定文件存储通道的名称、凭证、场景、优先级、状态和重试策略。
|
||||||
|
parameters:
|
||||||
|
- description: 文件存储通道 ID
|
||||||
|
in: path
|
||||||
|
name: channelID
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
- description: 文件存储通道
|
||||||
|
in: body
|
||||||
|
name: body
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageChannelInput'
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageChannel'
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"403":
|
||||||
|
description: Forbidden
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"404":
|
||||||
|
description: Not Found
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"409":
|
||||||
|
description: Conflict
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 更新文件存储通道
|
||||||
|
tags:
|
||||||
|
- system
|
||||||
|
/api/admin/system/file-storage/settings:
|
||||||
|
get:
|
||||||
|
description: 返回文件存储系统设置;数据库对象尚未创建时返回默认设置。
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageSettings'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"403":
|
||||||
|
description: Forbidden
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 获取文件存储设置
|
||||||
|
tags:
|
||||||
|
- system
|
||||||
|
patch:
|
||||||
|
consumes:
|
||||||
|
- application/json
|
||||||
|
description: 更新生成资源上传策略等文件存储系统设置。
|
||||||
|
parameters:
|
||||||
|
- description: 文件存储设置
|
||||||
|
in: body
|
||||||
|
name: body
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageSettingsInput'
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/store.FileStorageSettings'
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"403":
|
||||||
|
description: Forbidden
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 更新文件存储设置
|
||||||
|
tags:
|
||||||
|
- system
|
||||||
/api/admin/tenants:
|
/api/admin/tenants:
|
||||||
get:
|
get:
|
||||||
description: 管理端返回网关租户列表。
|
description: 管理端返回网关租户列表。
|
||||||
@ -4472,14 +4858,14 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: /api/v1/chat/completions 同步执行:stream=true 返回 text/event-stream
|
||||||
SSE 流。
|
SSE;stream=false 或未传返回兼容 JSON;该接口忽略 X-Async。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: 该接口忽略此参数
|
||||||
in: header
|
in: header
|
||||||
name: X-Async
|
name: X-Async
|
||||||
type: boolean
|
type: boolean
|
||||||
- description: AI 任务请求,字段随任务类型变化
|
- description: Chat Completions 请求
|
||||||
in: body
|
in: body
|
||||||
name: input
|
name: input
|
||||||
required: true
|
required: true
|
||||||
@ -4487,15 +4873,12 @@ paths:
|
|||||||
$ref: '#/definitions/httpapi.TaskRequest'
|
$ref: '#/definitions/httpapi.TaskRequest'
|
||||||
produces:
|
produces:
|
||||||
- application/json
|
- application/json
|
||||||
|
- text/event-stream
|
||||||
responses:
|
responses:
|
||||||
"200":
|
"200":
|
||||||
description: OK
|
description: OK
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/definitions/httpapi.CompatibleResponse'
|
$ref: '#/definitions/httpapi.ChatCompletionCompatibleResponse'
|
||||||
"202":
|
|
||||||
description: Accepted
|
|
||||||
schema:
|
|
||||||
$ref: '#/definitions/httpapi.TaskAcceptedResponse'
|
|
||||||
"400":
|
"400":
|
||||||
description: Bad Request
|
description: Bad Request
|
||||||
schema:
|
schema:
|
||||||
@ -4526,15 +4909,59 @@ paths:
|
|||||||
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
security:
|
security:
|
||||||
- BearerAuth: []
|
- BearerAuth: []
|
||||||
summary: 创建或执行 AI 任务
|
summary: 创建 Chat Completions
|
||||||
tags:
|
tags:
|
||||||
- tasks
|
- tasks
|
||||||
|
/api/v1/files/upload:
|
||||||
|
post:
|
||||||
|
consumes:
|
||||||
|
- multipart/form-data
|
||||||
|
description: 上传文件到配置的文件存储通道;没有启用通道时回退到本地静态上传目录。单文件最大 256MiB。
|
||||||
|
parameters:
|
||||||
|
- description: 要上传的文件
|
||||||
|
in: formData
|
||||||
|
name: file
|
||||||
|
required: true
|
||||||
|
type: file
|
||||||
|
- default: ai-gateway-openapi
|
||||||
|
description: 上传来源标识
|
||||||
|
in: formData
|
||||||
|
name: source
|
||||||
|
type: string
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.FileUploadResponse'
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"502":
|
||||||
|
description: Bad Gateway
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"503":
|
||||||
|
description: Service Unavailable
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 上传文件
|
||||||
|
tags:
|
||||||
|
- files
|
||||||
/api/v1/images/edits:
|
/api/v1/images/edits:
|
||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -4594,8 +5021,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -4848,8 +5275,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5062,8 +5489,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5363,8 +5790,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5437,8 +5864,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5498,8 +5925,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5576,8 +6003,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5633,6 +6060,29 @@ paths:
|
|||||||
summary: 创建或执行 AI 任务
|
summary: 创建或执行 AI 任务
|
||||||
tags:
|
tags:
|
||||||
- tasks
|
- tasks
|
||||||
|
/static/generated/{asset}:
|
||||||
|
get:
|
||||||
|
description: 从本地生成资源目录读取图片、视频等任务产物;不存在时返回 404。
|
||||||
|
parameters:
|
||||||
|
- description: 资源文件名
|
||||||
|
in: path
|
||||||
|
name: asset
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
produces:
|
||||||
|
- application/octet-stream
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
type: file
|
||||||
|
"404":
|
||||||
|
description: Not Found
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
summary: 获取本地生成资源
|
||||||
|
tags:
|
||||||
|
- static
|
||||||
/static/simulation/{asset}:
|
/static/simulation/{asset}:
|
||||||
get:
|
get:
|
||||||
description: 返回本地模拟模式使用的图片、视频封面或短视频资源。
|
description: 返回本地模拟模式使用的图片、视频封面或短视频资源。
|
||||||
@ -5657,12 +6107,35 @@ paths:
|
|||||||
summary: 获取模拟资源
|
summary: 获取模拟资源
|
||||||
tags:
|
tags:
|
||||||
- simulation
|
- simulation
|
||||||
|
/static/uploaded/{asset}:
|
||||||
|
get:
|
||||||
|
description: 从本地上传资源目录读取用户上传文件;不存在时返回 404。
|
||||||
|
parameters:
|
||||||
|
- description: 资源文件名
|
||||||
|
in: path
|
||||||
|
name: asset
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
produces:
|
||||||
|
- application/octet-stream
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
type: file
|
||||||
|
"404":
|
||||||
|
description: Not Found
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
summary: 获取本地上传资源
|
||||||
|
tags:
|
||||||
|
- static
|
||||||
/v1/chat/completions:
|
/v1/chat/completions:
|
||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5718,12 +6191,56 @@ paths:
|
|||||||
summary: 创建或执行 AI 任务
|
summary: 创建或执行 AI 任务
|
||||||
tags:
|
tags:
|
||||||
- tasks
|
- tasks
|
||||||
|
/v1/files/upload:
|
||||||
|
post:
|
||||||
|
consumes:
|
||||||
|
- multipart/form-data
|
||||||
|
description: 上传文件到配置的文件存储通道;没有启用通道时回退到本地静态上传目录。单文件最大 256MiB。
|
||||||
|
parameters:
|
||||||
|
- description: 要上传的文件
|
||||||
|
in: formData
|
||||||
|
name: file
|
||||||
|
required: true
|
||||||
|
type: file
|
||||||
|
- default: ai-gateway-openapi
|
||||||
|
description: 上传来源标识
|
||||||
|
in: formData
|
||||||
|
name: source
|
||||||
|
type: string
|
||||||
|
produces:
|
||||||
|
- application/json
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: OK
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.FileUploadResponse'
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"401":
|
||||||
|
description: Unauthorized
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"502":
|
||||||
|
description: Bad Gateway
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
"503":
|
||||||
|
description: Service Unavailable
|
||||||
|
schema:
|
||||||
|
$ref: '#/definitions/httpapi.ErrorEnvelope'
|
||||||
|
security:
|
||||||
|
- BearerAuth: []
|
||||||
|
summary: 上传文件
|
||||||
|
tags:
|
||||||
|
- files
|
||||||
/v1/images/edits:
|
/v1/images/edits:
|
||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5783,8 +6300,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
@ -5844,8 +6361,8 @@ paths:
|
|||||||
post:
|
post:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或
|
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible
|
||||||
SSE 流。
|
路径同步返回兼容响应或 SSE 流。
|
||||||
parameters:
|
parameters:
|
||||||
- description: true 时异步创建任务并返回 202
|
- description: true 时异步创建任务并返回 202
|
||||||
in: header
|
in: header
|
||||||
|
|||||||
@ -13,6 +13,11 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||||
|
github.com/dop251/goja v0.0.0-20260311135729-065cd970411c // indirect
|
||||||
|
github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14 // indirect
|
||||||
|
github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect
|
||||||
|
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
|||||||
@ -1,8 +1,18 @@
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
|
||||||
|
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
|
github.com/dop251/goja v0.0.0-20260311135729-065cd970411c h1:OcLmPfx1T1RmZVHHFwWMPaZDdRf0DBMZOFMVWJa7Pdk=
|
||||||
|
github.com/dop251/goja v0.0.0-20260311135729-065cd970411c/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
|
||||||
|
github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14 h1:3U8dTgyNBhEQ/GVw0jZW5q+93Zw2gAZPRWhJ9TwV3rM=
|
||||||
|
github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14/go.mod h1:Tb7Xxye4LX7cT3i8YLvmPMGCV92IOi4CDZvm/V8ylc0=
|
||||||
|
github.com/go-sourcemap/sourcemap v2.1.4+incompatible h1:a+iTbH5auLKxaNwQFg0B+TCYl6lbukKPc7b5x0n1s6Q=
|
||||||
|
github.com/go-sourcemap/sourcemap v2.1.4+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k=
|
||||||
|
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
|
||||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
|
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
|
||||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
|
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
|||||||
@ -110,6 +110,7 @@ func TestOpenAIClientChatContract(t *testing.T) {
|
|||||||
t.Fatalf("decode request: %v", err)
|
t.Fatalf("decode request: %v", err)
|
||||||
}
|
}
|
||||||
gotModel, _ = body["model"].(string)
|
gotModel, _ = body["model"].(string)
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"id": "chatcmpl-test",
|
"id": "chatcmpl-test",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
@ -145,6 +146,190 @@ func TestOpenAIClientChatContract(t *testing.T) {
|
|||||||
if response.RequestID != "req-chat-test" || response.ResponseStartedAt.IsZero() || response.ResponseFinishedAt.IsZero() {
|
if response.RequestID != "req-chat-test" || response.ResponseStartedAt.IsZero() || response.ResponseFinishedAt.IsZero() {
|
||||||
t.Fatalf("response metadata was not captured: %+v", response)
|
t.Fatalf("response metadata was not captured: %+v", response)
|
||||||
}
|
}
|
||||||
|
if response.ResponseDurationMS < 20 {
|
||||||
|
t.Fatalf("response duration should include upstream latency, got %dms", response.ResponseDurationMS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIClientChatRequestNormalizesToolContext(t *testing.T) {
|
||||||
|
var captured map[string]any
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&captured); err != nil {
|
||||||
|
t.Fatalf("decode request: %v", err)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": "chatcmpl-normalized-request",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": captured["model"],
|
||||||
|
"choices": []any{map[string]any{
|
||||||
|
"message": map[string]any{"role": "assistant", "content": "ok"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "openai:gpt-4o-mini",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "openai:gpt-4o-mini",
|
||||||
|
"messages": []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"functionCall": map[string]any{
|
||||||
|
"name": "lookup",
|
||||||
|
"arguments": map[string]any{"q": "weather"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{"role": "tool", "toolCallId": "call_0", "content": "sunny"},
|
||||||
|
map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "keep this"},
|
||||||
|
map[string]any{"type": "tool_result", "tool_use_id": "toolu_1", "content": map[string]any{"ok": true}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ProviderModelName: "openai-compatible-gpt-4o-mini",
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run openai client: %v", err)
|
||||||
|
}
|
||||||
|
messages, _ := captured["messages"].([]any)
|
||||||
|
if len(messages) != 4 {
|
||||||
|
t.Fatalf("unexpected normalized messages: %+v", messages)
|
||||||
|
}
|
||||||
|
assistant, _ := messages[0].(map[string]any)
|
||||||
|
if _, ok := assistant["functionCall"]; ok {
|
||||||
|
t.Fatalf("functionCall should be converted away: %+v", assistant)
|
||||||
|
}
|
||||||
|
toolCalls, _ := assistant["tool_calls"].([]any)
|
||||||
|
toolCall, _ := toolCalls[0].(map[string]any)
|
||||||
|
function, _ := toolCall["function"].(map[string]any)
|
||||||
|
if function["name"] != "lookup" || function["arguments"] != `{"q":"weather"}` {
|
||||||
|
t.Fatalf("unexpected normalized tool call: %+v", assistant)
|
||||||
|
}
|
||||||
|
toolMessage, _ := messages[1].(map[string]any)
|
||||||
|
if toolMessage["tool_call_id"] != "call_0" || toolMessage["toolCallId"] != nil {
|
||||||
|
t.Fatalf("tool message was not normalized: %+v", toolMessage)
|
||||||
|
}
|
||||||
|
keptUser, _ := messages[2].(map[string]any)
|
||||||
|
convertedToolResult, _ := messages[3].(map[string]any)
|
||||||
|
if keptUser["content"] != "keep this" || convertedToolResult["role"] != "tool" || convertedToolResult["tool_call_id"] != "toolu_1" || convertedToolResult["content"] != `{"ok":true}` {
|
||||||
|
t.Fatalf("tool_result block was not restored: user=%+v tool=%+v", keptUser, convertedToolResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIClientChatResponseNormalizesReasoning(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": "chatcmpl-reasoning",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": "openrouter-test",
|
||||||
|
"choices": []any{map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_details": []any{
|
||||||
|
map[string]any{"type": "reasoning.text", "text": "detail-"},
|
||||||
|
map[string]any{"type": "reasoning.summary", "summary": "summary"},
|
||||||
|
map[string]any{"type": "reasoning.encrypted", "data": "secret"},
|
||||||
|
},
|
||||||
|
"content": "<think>tagged</think>answer",
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "OpenRouter-Test",
|
||||||
|
Body: map[string]any{"model": "OpenRouter-Test", "messages": []any{map[string]any{"role": "user", "content": "ping"}}},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ModelName: "openrouter-test",
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run openai client: %v", err)
|
||||||
|
}
|
||||||
|
choices, _ := response.Result["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
message, _ := choice["message"].(map[string]any)
|
||||||
|
if message["reasoning_content"] != "detail-summarytagged" || message["content"] != "answer" {
|
||||||
|
t.Fatalf("reasoning was not normalized: %+v", response.Result)
|
||||||
|
}
|
||||||
|
if _, ok := message["reasoning_details"]; ok {
|
||||||
|
t.Fatalf("reasoning_details should be converted away: %+v", message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIClientChatResponseNormalizesToolCallFormats(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"id": "chatcmpl-tools",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": "tool-format-test",
|
||||||
|
"choices": []any{map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "calling tools"},
|
||||||
|
map[string]any{"type": "tool_use", "id": "toolu_1", "name": "anthropic_lookup", "input": map[string]any{"city": "Boston"}},
|
||||||
|
},
|
||||||
|
"toolCalls": []any{map[string]any{"id": "call_camel", "functionCall": map[string]any{"name": "camel_lookup", "args": map[string]any{"city": "SF"}}}},
|
||||||
|
"function_call": map[string]any{"name": "legacy_lookup", "arguments": "{\"city\":\"NYC\"}"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "Tool-Format-Test",
|
||||||
|
Body: map[string]any{"model": "Tool-Format-Test", "messages": []any{map[string]any{"role": "user", "content": "ping"}}},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ModelName: "tool-format-test",
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run openai client: %v", err)
|
||||||
|
}
|
||||||
|
choices, _ := response.Result["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
message, _ := choice["message"].(map[string]any)
|
||||||
|
if message["content"] != "calling tools" {
|
||||||
|
t.Fatalf("tool_use block should be removed from content: %+v", message)
|
||||||
|
}
|
||||||
|
for _, key := range []string{"toolCalls", "function_call"} {
|
||||||
|
if _, ok := message[key]; ok {
|
||||||
|
t.Fatalf("%s should be converted away: %+v", key, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolCalls, _ := message["tool_calls"].([]any)
|
||||||
|
if len(toolCalls) != 3 {
|
||||||
|
t.Fatalf("expected 3 normalized tool calls, got %+v", message)
|
||||||
|
}
|
||||||
|
assertToolCall := func(index int, id string, name string, arguments string) {
|
||||||
|
t.Helper()
|
||||||
|
toolCall, _ := toolCalls[index].(map[string]any)
|
||||||
|
function, _ := toolCall["function"].(map[string]any)
|
||||||
|
if toolCall["id"] != id || toolCall["type"] != "function" || function["name"] != name || function["arguments"] != arguments {
|
||||||
|
t.Fatalf("unexpected tool call[%d]: %+v", index, toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertToolCall(0, "call_camel", "camel_lookup", "{\"city\":\"SF\"}")
|
||||||
|
assertToolCall(1, "call_1", "legacy_lookup", "{\"city\":\"NYC\"}")
|
||||||
|
assertToolCall(2, "toolu_1", "anthropic_lookup", "{\"city\":\"Boston\"}")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIClientChatStreamContract(t *testing.T) {
|
func TestOpenAIClientChatStreamContract(t *testing.T) {
|
||||||
@ -200,6 +385,133 @@ func TestOpenAIClientChatStreamContract(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIClientChatStreamPreservesStructuredDeltas(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-structured\",\"object\":\"chat.completion.chunk\",\"model\":\"openrouter-reasoner\",\"choices\":[{\"delta\":{\"reasoning_details\":[{\"type\":\"reasoning.text\",\"text\":\"detail-\"},{\"type\":\"reasoning.summary\",\"summary\":\"summary\"},{\"type\":\"reasoning.encrypted\",\"data\":\"secret\"}]}}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-structured\",\"object\":\"chat.completion.chunk\",\"model\":\"openrouter-reasoner\",\"choices\":[{\"delta\":{\"content\":\"<think>tagged</think>answer\"}}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-structured\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"lookup\",\"arguments\":\"{\\\"q\\\":\"}}]}}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-structured\",\"object\":\"chat.completion.chunk\",\"model\":\"deepseek-v4\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"weather\\\"}\"}}]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
captured := make([]StreamDeltaEvent, 0)
|
||||||
|
response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "DeepSeek-V4",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "DeepSeek-V4",
|
||||||
|
"messages": []any{map[string]any{"role": "user", "content": "ping"}},
|
||||||
|
"stream": true,
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ModelName: "deepseek-v4",
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
},
|
||||||
|
StreamDelta: func(event StreamDeltaEvent) error {
|
||||||
|
captured = append(captured, event)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run openai structured stream client: %v", err)
|
||||||
|
}
|
||||||
|
if len(captured) != 4 || captured[0].ReasoningContent != "detail-summary" || captured[1].ReasoningContent != "tagged" || captured[1].Text != "answer" || captured[2].Event == nil {
|
||||||
|
t.Fatalf("structured stream events were not preserved: %+v", captured)
|
||||||
|
}
|
||||||
|
firstChoices, _ := captured[0].Event["choices"].([]any)
|
||||||
|
firstChoice, _ := firstChoices[0].(map[string]any)
|
||||||
|
firstDelta, _ := firstChoice["delta"].(map[string]any)
|
||||||
|
if firstDelta["reasoning_content"] != "detail-summary" {
|
||||||
|
t.Fatalf("reasoning_details were not converted in stream event: %+v", captured[0].Event)
|
||||||
|
}
|
||||||
|
if _, ok := firstDelta["reasoning_details"]; ok {
|
||||||
|
t.Fatalf("reasoning_details should be removed from stream event: %+v", captured[0].Event)
|
||||||
|
}
|
||||||
|
choices, _ := response.Result["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
message, _ := choice["message"].(map[string]any)
|
||||||
|
if message["reasoning_content"] != "detail-summarytagged" || message["content"] != "answer" || choice["finish_reason"] != "tool_calls" {
|
||||||
|
t.Fatalf("reasoning or finish reason missing from aggregated result: %+v", response.Result)
|
||||||
|
}
|
||||||
|
toolCalls, _ := message["tool_calls"].([]any)
|
||||||
|
toolCall, _ := toolCalls[0].(map[string]any)
|
||||||
|
function, _ := toolCall["function"].(map[string]any)
|
||||||
|
if function["arguments"] != "{\"q\":\"weather\"}" {
|
||||||
|
t.Fatalf("tool call arguments were not aggregated: %+v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIClientChatStreamNormalizesToolCallFormats(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-tools-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"tool-format-test\",\"choices\":[{\"delta\":{\"function_call\":{\"name\":\"legacy_lookup\",\"arguments\":\"{\\\"city\\\":\"}}}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-tools-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"tool-format-test\",\"choices\":[{\"delta\":{\"functionCall\":{\"arguments\":\"\\\"Boston\\\"}\"}}}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-tools-stream\",\"object\":\"chat.completion.chunk\",\"model\":\"tool-format-test\",\"choices\":[{\"delta\":{\"toolCall\":{\"index\":1,\"id\":\"call_camel\",\"functionCall\":{\"name\":\"camel_lookup\",\"args\":{\"city\":\"SF\"}}}},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
captured := make([]StreamDeltaEvent, 0)
|
||||||
|
response, err := (OpenAIClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "Tool-Format-Test",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "Tool-Format-Test",
|
||||||
|
"messages": []any{map[string]any{"role": "user", "content": "ping"}},
|
||||||
|
"stream": true,
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ModelName: "tool-format-test",
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
},
|
||||||
|
StreamDelta: func(event StreamDeltaEvent) error {
|
||||||
|
captured = append(captured, event)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run openai stream client: %v", err)
|
||||||
|
}
|
||||||
|
if len(captured) != 3 {
|
||||||
|
t.Fatalf("unexpected captured events: %+v", captured)
|
||||||
|
}
|
||||||
|
for _, event := range captured {
|
||||||
|
choices, _ := event.Event["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
if _, ok := delta["function_call"]; ok {
|
||||||
|
t.Fatalf("function_call should be converted away: %+v", event.Event)
|
||||||
|
}
|
||||||
|
if _, ok := delta["functionCall"]; ok {
|
||||||
|
t.Fatalf("functionCall should be converted away: %+v", event.Event)
|
||||||
|
}
|
||||||
|
if _, ok := delta["toolCall"]; ok {
|
||||||
|
t.Fatalf("toolCall should be converted away: %+v", event.Event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
choices, _ := response.Result["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
message, _ := choice["message"].(map[string]any)
|
||||||
|
toolCalls, _ := message["tool_calls"].([]any)
|
||||||
|
if len(toolCalls) != 2 || choice["finish_reason"] != "tool_calls" {
|
||||||
|
t.Fatalf("unexpected normalized stream result: %+v", response.Result)
|
||||||
|
}
|
||||||
|
legacyCall, _ := toolCalls[0].(map[string]any)
|
||||||
|
legacyFunction, _ := legacyCall["function"].(map[string]any)
|
||||||
|
if legacyFunction["name"] != "legacy_lookup" || legacyFunction["arguments"] != "{\"city\":\"Boston\"}" {
|
||||||
|
t.Fatalf("legacy function_call was not aggregated: %+v", response.Result)
|
||||||
|
}
|
||||||
|
camelCall, _ := toolCalls[1].(map[string]any)
|
||||||
|
camelFunction, _ := camelCall["function"].(map[string]any)
|
||||||
|
if camelCall["id"] != "call_camel" || camelFunction["name"] != "camel_lookup" || camelFunction["arguments"] != "{\"city\":\"SF\"}" {
|
||||||
|
t.Fatalf("camel toolCall was not normalized: %+v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGeminiClientChatContract(t *testing.T) {
|
func TestGeminiClientChatContract(t *testing.T) {
|
||||||
var gotPath string
|
var gotPath string
|
||||||
var gotKey string
|
var gotKey string
|
||||||
@ -257,6 +569,213 @@ func TestGeminiClientChatContract(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeminiClientChatConvertsMediaContentParts(t *testing.T) {
|
||||||
|
var captured map[string]any
|
||||||
|
var gotPath string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&captured); err != nil {
|
||||||
|
t.Fatalf("decode request: %v", err)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"candidates": []any{map[string]any{
|
||||||
|
"content": map[string]any{"parts": []any{map[string]any{"text": "video ok"}}},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "gemini:gemini-2.5-flash",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "gemini:gemini-2.5-flash",
|
||||||
|
"messages": []any{map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "analyze this video"},
|
||||||
|
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://cdn.example.com/input.mov", "mime_type": "video/quicktime"}},
|
||||||
|
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "data:audio/wav;base64,UklGRg=="}},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL + "/v1beta/openai",
|
||||||
|
ProviderModelName: "gemini-2.5-flash",
|
||||||
|
ModelType: "chat",
|
||||||
|
Credentials: map[string]any{"apiKey": "gemini-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run gemini client: %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" {
|
||||||
|
t.Fatalf("Gemini OpenAI-compatible base URL should normalize to native endpoint, got %s", gotPath)
|
||||||
|
}
|
||||||
|
contents, _ := captured["contents"].([]any)
|
||||||
|
if len(contents) != 1 {
|
||||||
|
t.Fatalf("unexpected Gemini contents: %+v", captured)
|
||||||
|
}
|
||||||
|
turn, _ := contents[0].(map[string]any)
|
||||||
|
parts, _ := turn["parts"].([]any)
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("expected text, video, and audio parts, got %+v", turn)
|
||||||
|
}
|
||||||
|
video, _ := parts[1].(map[string]any)
|
||||||
|
videoFile, _ := video["fileData"].(map[string]any)
|
||||||
|
if videoFile["fileUri"] != "https://cdn.example.com/input.mov" || videoFile["mimeType"] != "video/quicktime" {
|
||||||
|
t.Fatalf("video_url should become Gemini fileData, got %+v", video)
|
||||||
|
}
|
||||||
|
audio, _ := parts[2].(map[string]any)
|
||||||
|
audioInline, _ := audio["inlineData"].(map[string]any)
|
||||||
|
if audioInline["mimeType"] != "audio/wav" || audioInline["data"] != "UklGRg==" {
|
||||||
|
t.Fatalf("audio data URL should become Gemini inlineData, got %+v", audio)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiClientChatRestoresToolContext(t *testing.T) {
|
||||||
|
var captured map[string]any
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&captured); err != nil {
|
||||||
|
t.Fatalf("decode request: %v", err)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"candidates": []any{map[string]any{
|
||||||
|
"content": map[string]any{"parts": []any{map[string]any{"text": "gemini ok"}}},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "gemini:gemini-2.5-flash",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "gemini:gemini-2.5-flash",
|
||||||
|
"messages": []any{
|
||||||
|
map[string]any{"role": "user", "content": "weather?"},
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "checking",
|
||||||
|
"tool_calls": []any{map[string]any{
|
||||||
|
"id": "call_weather",
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": `{"city":"SF"}`,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
map[string]any{"role": "tool", "tool_call_id": "call_weather", "content": `{"temperature":"72F"}`},
|
||||||
|
},
|
||||||
|
"tools": []any{map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "lookup weather",
|
||||||
|
"parameters": map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ProviderModelName: "gemini-2.5-flash",
|
||||||
|
ModelType: "chat",
|
||||||
|
Credentials: map[string]any{"apiKey": "gemini-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run gemini client: %v", err)
|
||||||
|
}
|
||||||
|
contents, _ := captured["contents"].([]any)
|
||||||
|
if len(contents) != 3 {
|
||||||
|
t.Fatalf("unexpected Gemini contents: %+v", captured)
|
||||||
|
}
|
||||||
|
modelTurn, _ := contents[1].(map[string]any)
|
||||||
|
if modelTurn["role"] != "model" {
|
||||||
|
t.Fatalf("assistant turn should become Gemini model turn: %+v", modelTurn)
|
||||||
|
}
|
||||||
|
modelParts, _ := modelTurn["parts"].([]any)
|
||||||
|
callPart, _ := modelParts[1].(map[string]any)
|
||||||
|
functionCall, _ := callPart["functionCall"].(map[string]any)
|
||||||
|
args, _ := functionCall["args"].(map[string]any)
|
||||||
|
if functionCall["name"] != "get_weather" || args["city"] != "SF" {
|
||||||
|
t.Fatalf("tool call was not restored for Gemini: %+v", modelTurn)
|
||||||
|
}
|
||||||
|
toolTurn, _ := contents[2].(map[string]any)
|
||||||
|
toolParts, _ := toolTurn["parts"].([]any)
|
||||||
|
responsePart, _ := toolParts[0].(map[string]any)
|
||||||
|
functionResponse, _ := responsePart["functionResponse"].(map[string]any)
|
||||||
|
response, _ := functionResponse["response"].(map[string]any)
|
||||||
|
if toolTurn["role"] != "user" || functionResponse["name"] != "get_weather" || response["temperature"] != "72F" {
|
||||||
|
t.Fatalf("tool result was not restored for Gemini: %+v", toolTurn)
|
||||||
|
}
|
||||||
|
tools, _ := captured["tools"].([]any)
|
||||||
|
declarationGroup, _ := tools[0].(map[string]any)
|
||||||
|
declarations, _ := declarationGroup["functionDeclarations"].([]any)
|
||||||
|
declaration, _ := declarations[0].(map[string]any)
|
||||||
|
if declaration["name"] != "get_weather" || declaration["description"] != "lookup weather" {
|
||||||
|
t.Fatalf("tool declaration was not converted for Gemini: %+v", captured["tools"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiClientChatConvertsFunctionCallResponse(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"candidates": []any{map[string]any{
|
||||||
|
"finishReason": "STOP",
|
||||||
|
"content": map[string]any{"parts": []any{
|
||||||
|
map[string]any{"functionCall": map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"args": map[string]any{"city": "SF"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
response, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "gemini:gemini-2.5-flash",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "gemini:gemini-2.5-flash",
|
||||||
|
"messages": []any{map[string]any{"role": "user", "content": "weather?"}},
|
||||||
|
"tools": []any{map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{"name": "get_weather"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
ProviderModelName: "gemini-2.5-flash",
|
||||||
|
ModelType: "chat",
|
||||||
|
Credentials: map[string]any{"apiKey": "gemini-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run gemini client: %v", err)
|
||||||
|
}
|
||||||
|
choices, _ := response.Result["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
if choice["finish_reason"] != "tool_calls" {
|
||||||
|
t.Fatalf("Gemini function call should use tool_calls finish reason: %+v", response.Result)
|
||||||
|
}
|
||||||
|
message, _ := choice["message"].(map[string]any)
|
||||||
|
if message["content"] != nil {
|
||||||
|
t.Fatalf("tool-only Gemini response should keep nullable content: %+v", message)
|
||||||
|
}
|
||||||
|
toolCalls, _ := message["tool_calls"].([]any)
|
||||||
|
if len(toolCalls) != 1 {
|
||||||
|
t.Fatalf("Gemini function call was not converted: %+v", message)
|
||||||
|
}
|
||||||
|
toolCall, _ := toolCalls[0].(map[string]any)
|
||||||
|
function, _ := toolCall["function"].(map[string]any)
|
||||||
|
if toolCall["type"] != "function" || toolCall["id"] != "call_0" || function["name"] != "get_weather" || function["arguments"] != `{"city":"SF"}` {
|
||||||
|
t.Fatalf("unexpected Gemini tool call: %+v", toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGeminiURLAcceptsVersionedBaseURL(t *testing.T) {
|
func TestGeminiURLAcceptsVersionedBaseURL(t *testing.T) {
|
||||||
got := geminiURL("https://generativelanguage.googleapis.com/v1beta", "gemini-2.5-flash", "test-key")
|
got := geminiURL("https://generativelanguage.googleapis.com/v1beta", "gemini-2.5-flash", "test-key")
|
||||||
want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key=test-key"
|
want := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent?key=test-key"
|
||||||
@ -662,6 +1181,266 @@ func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKelingClientVideoSubmitsAndPollsImageTask(t *testing.T) {
|
||||||
|
var submitPath string
|
||||||
|
var pollPath string
|
||||||
|
var gotAuth string
|
||||||
|
var submittedTaskID string
|
||||||
|
var submittedPayload map[string]any
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotAuth = r.Header.Get("Authorization")
|
||||||
|
switch r.Method + " " + r.URL.Path {
|
||||||
|
case "POST /videos/image2video":
|
||||||
|
submitPath = r.URL.Path
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&submittedPayload); err != nil {
|
||||||
|
t.Fatalf("decode keling submit: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := submittedPayload["aspect_ratio"]; ok {
|
||||||
|
t.Fatalf("image2video payload should not include aspect_ratio: %+v", submittedPayload)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"request_id": "req-submit",
|
||||||
|
"data": map[string]any{"task_id": "keling-task-1"},
|
||||||
|
})
|
||||||
|
case "GET /videos/image2video/keling-task-1":
|
||||||
|
pollPath = r.URL.Path
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"request_id": "req-poll",
|
||||||
|
"data": map[string]any{
|
||||||
|
"task_id": "keling-task-1",
|
||||||
|
"task_status": "succeed",
|
||||||
|
"created_at": 456,
|
||||||
|
"task_result": map[string]any{
|
||||||
|
"videos": []any{map[string]any{"url": "https://example.com/keling.mp4", "duration": 6}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
response, err := (KelingClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "image_to_video",
|
||||||
|
Model: "可灵2.6",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "可灵2.6",
|
||||||
|
"prompt": "A clean product reveal",
|
||||||
|
"first_frame": "data:image/png;base64,Zmlyc3Q=",
|
||||||
|
"last_frame": "data:image/png;base64,bGFzdA==",
|
||||||
|
"duration": 6,
|
||||||
|
"resolution": "1080p",
|
||||||
|
"aspect_ratio": "16:9",
|
||||||
|
"audio": true,
|
||||||
|
"camera_control": "simple:zoom",
|
||||||
|
"camera_control_strength": 0.6,
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Provider: "keling",
|
||||||
|
AuthType: "AccessKey-SecretKey",
|
||||||
|
ModelName: "可灵2.6",
|
||||||
|
ProviderModelName: "kling-v2-6",
|
||||||
|
Credentials: map[string]any{"accessKey": "ak", "secretKey": "sk"},
|
||||||
|
PlatformConfig: map[string]any{
|
||||||
|
"kelingPollIntervalMs": 100,
|
||||||
|
"kelingPollTimeoutSeconds": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
|
||||||
|
submittedTaskID = remoteTaskID
|
||||||
|
if payload["endpoint"] != "/videos/image2video" || payload["taskType"] != "image2video" {
|
||||||
|
t.Fatalf("unexpected submitted keling payload: %+v", payload)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run keling video: %v", err)
|
||||||
|
}
|
||||||
|
if submitPath != "/videos/image2video" || pollPath != "/videos/image2video/keling-task-1" || !strings.HasPrefix(gotAuth, "Bearer ") {
|
||||||
|
t.Fatalf("unexpected keling paths/auth submit=%s poll=%s auth=%s", submitPath, pollPath, gotAuth)
|
||||||
|
}
|
||||||
|
if submittedTaskID != "keling-task-1" {
|
||||||
|
t.Fatalf("remote task submit callback did not receive task id, got %q", submittedTaskID)
|
||||||
|
}
|
||||||
|
if submittedPayload["model_name"] != "kling-v2-6" ||
|
||||||
|
submittedPayload["prompt"] != "A clean product reveal" ||
|
||||||
|
submittedPayload["duration"] != "6" ||
|
||||||
|
submittedPayload["mode"] != "pro" ||
|
||||||
|
submittedPayload["sound"] != "on" ||
|
||||||
|
submittedPayload["image"] != "Zmlyc3Q=" ||
|
||||||
|
submittedPayload["image_tail"] != "bGFzdA==" {
|
||||||
|
t.Fatalf("unexpected keling submit payload: %+v", submittedPayload)
|
||||||
|
}
|
||||||
|
camera, _ := submittedPayload["camera_control"].(map[string]any)
|
||||||
|
config, _ := camera["config"].(map[string]any)
|
||||||
|
if camera["type"] != "simple" || numericValue(config["zoom"], 0) != 0.6 || numericValue(config["pan"], -1) != 0 {
|
||||||
|
t.Fatalf("unexpected keling camera conversion: %+v", submittedPayload["camera_control"])
|
||||||
|
}
|
||||||
|
data, _ := response.Result["data"].([]any)
|
||||||
|
item, _ := data[0].(map[string]any)
|
||||||
|
if response.Result["upstream_task_id"] != "keling-task-1" || item["url"] != "https://example.com/keling.mp4" || item["video_url"] != "https://example.com/keling.mp4" {
|
||||||
|
t.Fatalf("unexpected keling response: %+v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKelingOmniPayloadConvertsGatewayContent(t *testing.T) {
|
||||||
|
payload, cleanupIDs, err := (KelingClient{}).kelingOmniPayload(context.Background(), Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "omni_video",
|
||||||
|
Model: "可灵V3多模态",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "可灵V3多模态",
|
||||||
|
"duration": 8,
|
||||||
|
"aspect_ratio": "9:16",
|
||||||
|
"resolution": "2160p",
|
||||||
|
"audio": true,
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "Refine the base video"},
|
||||||
|
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}},
|
||||||
|
map[string]any{"type": "image_url", "role": "last_frame", "image_url": map[string]any{"url": "https://example.com/last.png"}},
|
||||||
|
map[string]any{
|
||||||
|
"type": "video_url",
|
||||||
|
"role": "video_base",
|
||||||
|
"video_url": map[string]any{
|
||||||
|
"url": "https://example.com/base.mp4",
|
||||||
|
"keep_original_sound": "yes",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
Provider: "keling",
|
||||||
|
ProviderModelName: "kling-v3-omni",
|
||||||
|
Capabilities: map[string]any{"omni_video": map[string]any{}},
|
||||||
|
},
|
||||||
|
}, "token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("build keling omni payload: %v", err)
|
||||||
|
}
|
||||||
|
if len(cleanupIDs) != 0 {
|
||||||
|
t.Fatalf("unexpected cleanup ids: %+v", cleanupIDs)
|
||||||
|
}
|
||||||
|
if payload["model_name"] != "kling-v3-omni" || payload["mode"] != "4k" || payload["prompt"] != "Refine the base video" {
|
||||||
|
t.Fatalf("unexpected keling omni base fields: %+v", payload)
|
||||||
|
}
|
||||||
|
if _, ok := payload["sound"]; ok {
|
||||||
|
t.Fatalf("omni payload with base video should not include sound: %+v", payload)
|
||||||
|
}
|
||||||
|
if _, ok := payload["duration"]; ok {
|
||||||
|
t.Fatalf("base video edit should not include duration: %+v", payload)
|
||||||
|
}
|
||||||
|
if _, ok := payload["aspect_ratio"]; ok {
|
||||||
|
t.Fatalf("base video edit should not include aspect_ratio: %+v", payload)
|
||||||
|
}
|
||||||
|
watermark, _ := payload["watermark_info"].(map[string]any)
|
||||||
|
if watermark["enabled"] != false {
|
||||||
|
t.Fatalf("keling watermark should be disabled by default: %+v", payload)
|
||||||
|
}
|
||||||
|
images, _ := payload["image_list"].([]any)
|
||||||
|
if len(images) != 2 {
|
||||||
|
t.Fatalf("unexpected keling image_list: %+v", payload["image_list"])
|
||||||
|
}
|
||||||
|
firstImage, _ := images[0].(map[string]any)
|
||||||
|
lastImage, _ := images[1].(map[string]any)
|
||||||
|
if firstImage["type"] != "first_frame" || lastImage["type"] != "end_frame" {
|
||||||
|
t.Fatalf("frame roles should convert to keling omni types: %+v", images)
|
||||||
|
}
|
||||||
|
videos, _ := payload["video_list"].([]map[string]any)
|
||||||
|
if len(videos) != 1 || videos[0]["refer_type"] != "base" || videos[0]["keep_original_sound"] != "yes" {
|
||||||
|
t.Fatalf("video roles should convert to keling omni refer_type: %+v", payload["video_list"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKelingClientVideoResumePollsWithoutSubmitting(t *testing.T) {
|
||||||
|
var submitCalled bool
|
||||||
|
var pollPath string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method + " " + r.URL.Path {
|
||||||
|
case "POST /general/custom-elements", "POST /videos/omni-video":
|
||||||
|
submitCalled = true
|
||||||
|
t.Fatalf("resume should not submit or upload temporary elements")
|
||||||
|
case "GET /videos/omni-video/keling-existing":
|
||||||
|
pollPath = r.URL.Path
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"code": 0,
|
||||||
|
"request_id": "req-resume",
|
||||||
|
"data": map[string]any{
|
||||||
|
"task_id": "keling-existing",
|
||||||
|
"task_status": "succeed",
|
||||||
|
"task_result": map[string]any{
|
||||||
|
"videos": []any{map[string]any{"url": "https://example.com/resumed-keling.mp4"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request %s %s", r.Method, r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
response, err := (KelingClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "omni_video",
|
||||||
|
Model: "可灵V3多模态",
|
||||||
|
Body: map[string]any{"prompt": "resume", "pollIntervalMs": 100, "pollTimeoutSeconds": 1},
|
||||||
|
RemoteTaskID: "keling-existing",
|
||||||
|
RemoteTaskPayload: map[string]any{
|
||||||
|
"endpoint": "/videos/omni-video",
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Provider: "keling",
|
||||||
|
AuthType: "AccessKey-SecretKey",
|
||||||
|
ProviderModelName: "kling-v3-omni",
|
||||||
|
Credentials: map[string]any{"accessKey": "ak", "secretKey": "sk"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resume keling video: %v", err)
|
||||||
|
}
|
||||||
|
if submitCalled || pollPath != "/videos/omni-video/keling-existing" {
|
||||||
|
t.Fatalf("resume should poll existing task only, submit=%v poll=%s", submitCalled, pollPath)
|
||||||
|
}
|
||||||
|
data, _ := response.Result["data"].([]any)
|
||||||
|
item, _ := data[0].(map[string]any)
|
||||||
|
if response.Result["upstream_task_id"] != "keling-existing" || item["url"] != "https://example.com/resumed-keling.mp4" {
|
||||||
|
t.Fatalf("unexpected resumed keling response: %+v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKelingElementPayloadMapsTags(t *testing.T) {
|
||||||
|
payload := kelingCreateElementPayload(map[string]any{
|
||||||
|
"name": "subject",
|
||||||
|
"frontal_image_url": "https://example.com/front.png",
|
||||||
|
"tags": []any{"character", "unknown"},
|
||||||
|
"refer_images": []any{
|
||||||
|
map[string]any{"url": "https://example.com/side.png"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if payload["element_name"] != "subject" || payload["element_frontal_image"] != "https://example.com/front.png" {
|
||||||
|
t.Fatalf("unexpected element payload base fields: %+v", payload)
|
||||||
|
}
|
||||||
|
tags, _ := payload["tag_list"].([]any)
|
||||||
|
if len(tags) != 2 {
|
||||||
|
t.Fatalf("unexpected tag list: %+v", payload["tag_list"])
|
||||||
|
}
|
||||||
|
firstTag, _ := tags[0].(map[string]any)
|
||||||
|
secondTag, _ := tags[1].(map[string]any)
|
||||||
|
if firstTag["tag_id"] != "o_102" || secondTag["tag_id"] != "o_108" {
|
||||||
|
t.Fatalf("unexpected keling tag conversion: %+v", payload["tag_list"])
|
||||||
|
}
|
||||||
|
refs, _ := payload["element_refer_list"].([]any)
|
||||||
|
if len(refs) != 1 {
|
||||||
|
t.Fatalf("unexpected element references: %+v", payload["element_refer_list"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func extractText(result map[string]any) string {
|
func extractText(result map[string]any) string {
|
||||||
choices, _ := result["choices"].([]any)
|
choices, _ := result["choices"].([]any)
|
||||||
choice, _ := choices[0].(map[string]any)
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
|||||||
@ -5,8 +5,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -27,11 +29,11 @@ func (c GeminiClient) Run(ctx context.Context, request Request) (Response, error
|
|||||||
return Response{}, err
|
return Response{}, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
responseStartedAt := time.Now()
|
||||||
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
}
|
}
|
||||||
responseStartedAt := time.Now()
|
|
||||||
requestID := requestIDFromHTTPResponse(resp)
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
result, err := decodeHTTPResponse(resp)
|
result, err := decodeHTTPResponse(resp)
|
||||||
responseFinishedAt := time.Now()
|
responseFinishedAt := time.Now()
|
||||||
@ -58,6 +60,7 @@ func geminiURL(baseURL string, model string, apiKey string) string {
|
|||||||
if base == "" {
|
if base == "" {
|
||||||
base = "https://generativelanguage.googleapis.com"
|
base = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
|
base = strings.TrimSuffix(base, "/openai")
|
||||||
if strings.HasSuffix(base, "/v1beta") {
|
if strings.HasSuffix(base, "/v1beta") {
|
||||||
base = strings.TrimSuffix(base, "/v1beta")
|
base = strings.TrimSuffix(base, "/v1beta")
|
||||||
}
|
}
|
||||||
@ -70,15 +73,317 @@ func geminiBody(request Request) map[string]any {
|
|||||||
return map[string]any{"contents": contents}
|
return map[string]any{"contents": contents}
|
||||||
}
|
}
|
||||||
prompt := firstNonEmptyPrompt(request.Body, "")
|
prompt := firstNonEmptyPrompt(request.Body, "")
|
||||||
if prompt == "" {
|
if prompt != "" {
|
||||||
prompt = textFromMessages(request.Body)
|
|
||||||
}
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"contents": []any{map[string]any{
|
"contents": []any{map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": []any{map[string]any{"text": prompt}},
|
"parts": []any{map[string]any{"text": prompt}},
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
body := map[string]any{"contents": geminiContentsFromMessages(request.Body)}
|
||||||
|
if tools := geminiToolsFromOpenAITools(request.Body["tools"]); len(tools) > 0 {
|
||||||
|
body["tools"] = tools
|
||||||
|
}
|
||||||
|
contents, _ := body["contents"].([]any)
|
||||||
|
if len(contents) > 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return map[string]any{"contents": []any{map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"parts": []any{map[string]any{"text": textFromMessages(request.Body)}},
|
||||||
|
}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiContentsFromMessages(body map[string]any) []any {
|
||||||
|
normalized := NormalizeChatCompletionRequestBody(body)
|
||||||
|
messages, _ := normalized["messages"].([]any)
|
||||||
|
contents := make([]any, 0, len(messages))
|
||||||
|
toolNames := map[string]string{}
|
||||||
|
for _, rawMessage := range messages {
|
||||||
|
message, _ := rawMessage.(map[string]any)
|
||||||
|
if len(message) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
role := stringFromAny(message["role"])
|
||||||
|
if role == "tool" {
|
||||||
|
toolCallID := stringFromAny(message["tool_call_id"])
|
||||||
|
name := toolNames[toolCallID]
|
||||||
|
if name == "" {
|
||||||
|
name = toolCallID
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = "tool"
|
||||||
|
}
|
||||||
|
contents = append(contents, map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"parts": []any{map[string]any{"functionResponse": map[string]any{
|
||||||
|
"name": name,
|
||||||
|
"response": geminiFunctionResponsePayload(message["content"]),
|
||||||
|
}}},
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := geminiContentParts(message["content"])
|
||||||
|
if role == "assistant" {
|
||||||
|
for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
|
||||||
|
toolCall, _ := rawToolCall.(map[string]any)
|
||||||
|
function, _ := toolCall["function"].(map[string]any)
|
||||||
|
name := stringFromAny(function["name"])
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id := stringFromAny(toolCall["id"]); id != "" {
|
||||||
|
toolNames[id] = name
|
||||||
|
}
|
||||||
|
parts = append(parts, map[string]any{"functionCall": map[string]any{
|
||||||
|
"name": name,
|
||||||
|
"args": geminiFunctionArgs(function["arguments"]),
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contents = append(contents, map[string]any{
|
||||||
|
"role": geminiRole(role),
|
||||||
|
"parts": parts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return contents
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiRole(role string) string {
|
||||||
|
if role == "assistant" {
|
||||||
|
return "model"
|
||||||
|
}
|
||||||
|
return "user"
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiContentParts(content any) []any {
|
||||||
|
parts := make([]any, 0)
|
||||||
|
switch typed := content.(type) {
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(typed) != "" {
|
||||||
|
parts = append(parts, map[string]any{"text": typed})
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, rawPart := range typed {
|
||||||
|
part, _ := rawPart.(map[string]any)
|
||||||
|
if len(part) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch stringFromAny(part["type"]) {
|
||||||
|
case "text":
|
||||||
|
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
|
||||||
|
parts = append(parts, map[string]any{"text": text})
|
||||||
|
}
|
||||||
|
case "image_url":
|
||||||
|
if media := geminiMediaPart(part, "image_url", "image"); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
case "video_url":
|
||||||
|
if media := geminiMediaPart(part, "video_url", "video"); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
case "audio_url":
|
||||||
|
if media := geminiMediaPart(part, "audio_url", "audio"); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
case "input_audio":
|
||||||
|
if media := geminiInputAudioPart(part); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
|
||||||
|
parts = append(parts, map[string]any{"text": text})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiMediaPart(part map[string]any, key string, mediaType string) map[string]any {
|
||||||
|
nested := mapFromAny(part[key])
|
||||||
|
uri := firstNonEmptyString(nested["url"], part["url"], part[key])
|
||||||
|
if uri == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mimeType := firstNonEmptyString(nested["mime_type"], nested["mimeType"], part["mime_type"], part["mimeType"])
|
||||||
|
return geminiMediaURLPart(uri, mimeType, mediaType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiInputAudioPart(part map[string]any) map[string]any {
|
||||||
|
audio := mapFromAny(part["input_audio"])
|
||||||
|
uri := firstNonEmptyString(audio["data"], audio["url"])
|
||||||
|
if uri == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mimeType := firstNonEmptyString(audio["mime_type"], audio["mimeType"])
|
||||||
|
if mimeType == "" {
|
||||||
|
format := strings.ToLower(strings.TrimPrefix(stringFromAny(audio["format"]), "."))
|
||||||
|
if strings.Contains(format, "/") {
|
||||||
|
mimeType = format
|
||||||
|
} else if format == "mp3" {
|
||||||
|
mimeType = "audio/mpeg"
|
||||||
|
} else if format != "" {
|
||||||
|
mimeType = "audio/" + format
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return geminiMediaURLPart(uri, mimeType, "audio")
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiMediaURLPart(uri string, explicitMimeType string, mediaType string) map[string]any {
|
||||||
|
if parsed := geminiDataURL(uri); parsed != nil {
|
||||||
|
return map[string]any{"inlineData": map[string]any{
|
||||||
|
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, parsed.mimeType), mediaType),
|
||||||
|
"data": parsed.data,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
return map[string]any{"fileData": map[string]any{
|
||||||
|
"fileUri": uri,
|
||||||
|
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, mimeFromURI(uri)), mediaType),
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type geminiParsedDataURL struct {
|
||||||
|
mimeType string
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiDataURL(value string) *geminiParsedDataURL {
|
||||||
|
if !strings.HasPrefix(value, "data:") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
prefix, data, ok := strings.Cut(value, ",")
|
||||||
|
if !ok || !strings.Contains(prefix, ";base64") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mimeType := strings.TrimPrefix(strings.Split(prefix, ";")[0], "data:")
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
return &geminiParsedDataURL{mimeType: mimeType, data: data}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mimeFromURI(value string) string {
|
||||||
|
pathValue := value
|
||||||
|
if parsed, err := url.Parse(value); err == nil && parsed.Path != "" {
|
||||||
|
pathValue = parsed.Path
|
||||||
|
}
|
||||||
|
extension := strings.ToLower(path.Ext(pathValue))
|
||||||
|
if extension == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return mime.TypeByExtension(extension)
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiMediaMime(mimeType string, mediaType string) string {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(strings.Split(mimeType, ";")[0]))
|
||||||
|
switch mediaType {
|
||||||
|
case "image":
|
||||||
|
if strings.HasPrefix(normalized, "image/") && normalized != "image/svg+xml" {
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
return "image/png"
|
||||||
|
case "video":
|
||||||
|
switch normalized {
|
||||||
|
case "video/x-msvideo":
|
||||||
|
return "video/avi"
|
||||||
|
case "video/quicktime", "video/mpeg", "video/mp4", "video/avi", "video/x-flv", "video/mpg", "video/webm", "video/wmv", "video/3gpp":
|
||||||
|
return normalized
|
||||||
|
default:
|
||||||
|
return "video/mp4"
|
||||||
|
}
|
||||||
|
case "audio":
|
||||||
|
switch normalized {
|
||||||
|
case "audio/x-wav", "audio/wave":
|
||||||
|
return "audio/wav"
|
||||||
|
case "audio/mpeg", "audio/mp3", "audio/wav", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", "audio/mp4", "audio/webm":
|
||||||
|
return normalized
|
||||||
|
default:
|
||||||
|
return "audio/mpeg"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "application/octet-stream"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallsSlice(value any) []any {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case []any:
|
||||||
|
return typed
|
||||||
|
case map[string]any:
|
||||||
|
return []any{typed}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiFunctionArgs(value any) map[string]any {
|
||||||
|
if value == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
if args, ok := value.(map[string]any); ok {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
if text, ok := value.(string); ok {
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(text), &args); err == nil {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
return map[string]any{"arguments": text}
|
||||||
|
}
|
||||||
|
return map[string]any{"arguments": value}
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiFunctionResponsePayload(value any) map[string]any {
|
||||||
|
if payload, ok := value.(map[string]any); ok {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
if text, ok := value.(string); ok {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(text), &payload); err == nil {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
return map[string]any{"content": text}
|
||||||
|
}
|
||||||
|
if value == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
return map[string]any{"content": value}
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiToolsFromOpenAITools(value any) []any {
|
||||||
|
tools, ok := value.([]any)
|
||||||
|
if !ok || len(tools) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
declarations := make([]any, 0, len(tools))
|
||||||
|
for _, rawTool := range tools {
|
||||||
|
tool, _ := rawTool.(map[string]any)
|
||||||
|
function, _ := tool["function"].(map[string]any)
|
||||||
|
name := stringFromAny(function["name"])
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
declaration := map[string]any{"name": name}
|
||||||
|
if description := stringFromAny(function["description"]); description != "" {
|
||||||
|
declaration["description"] = description
|
||||||
|
}
|
||||||
|
if parameters, ok := function["parameters"]; ok {
|
||||||
|
declaration["parameters"] = parameters
|
||||||
|
}
|
||||||
|
declarations = append(declarations, declaration)
|
||||||
|
}
|
||||||
|
if len(declarations) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []any{map[string]any{"functionDeclarations": declarations}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiResult(request Request, raw map[string]any) map[string]any {
|
func geminiResult(request Request, raw map[string]any) map[string]any {
|
||||||
@ -95,7 +400,7 @@ func geminiResult(request Request, raw map[string]any) map[string]any {
|
|||||||
"raw": raw,
|
"raw": raw,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
content := geminiText(raw)
|
message, finishReason := geminiChatMessage(raw)
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"id": "gemini-chat",
|
"id": "gemini-chat",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
@ -103,8 +408,8 @@ func geminiResult(request Request, raw map[string]any) map[string]any {
|
|||||||
"model": request.Model,
|
"model": request.Model,
|
||||||
"choices": []any{map[string]any{
|
"choices": []any{map[string]any{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": "stop",
|
"finish_reason": finishReason,
|
||||||
"message": map[string]any{"role": "assistant", "content": content},
|
"message": message,
|
||||||
}},
|
}},
|
||||||
"usage": geminiUsageMap(raw),
|
"usage": geminiUsageMap(raw),
|
||||||
"raw": raw,
|
"raw": raw,
|
||||||
@ -133,19 +438,59 @@ func textFromMessages(body map[string]any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func geminiText(raw map[string]any) string {
|
func geminiText(raw map[string]any) string {
|
||||||
|
message, _ := geminiChatMessage(raw)
|
||||||
|
content, _ := message["content"].(string)
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiChatMessage(raw map[string]any) (map[string]any, string) {
|
||||||
candidates, _ := raw["candidates"].([]any)
|
candidates, _ := raw["candidates"].([]any)
|
||||||
for _, candidate := range candidates {
|
for _, candidate := range candidates {
|
||||||
candidateMap, _ := candidate.(map[string]any)
|
candidateMap, _ := candidate.(map[string]any)
|
||||||
content, _ := candidateMap["content"].(map[string]any)
|
content, _ := candidateMap["content"].(map[string]any)
|
||||||
parts, _ := content["parts"].([]any)
|
parts, _ := content["parts"].([]any)
|
||||||
|
textParts := make([]string, 0, len(parts))
|
||||||
|
toolCalls := make([]any, 0)
|
||||||
for _, part := range parts {
|
for _, part := range parts {
|
||||||
partMap, _ := part.(map[string]any)
|
partMap, _ := part.(map[string]any)
|
||||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||||
return text
|
textParts = append(textParts, text)
|
||||||
|
}
|
||||||
|
functionCall := mapFromAny(firstPresent(partMap["functionCall"], partMap["function_call"]))
|
||||||
|
if len(functionCall) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if toolCall := normalizeGeminiFunctionCall(functionCall, len(toolCalls), false); toolCall != nil {
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
message := map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": strings.Join(textParts, ""),
|
||||||
|
}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
message["tool_calls"] = toolCalls
|
||||||
|
if len(textParts) == 0 {
|
||||||
|
message["content"] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return message, geminiFinishReason(candidateMap, len(toolCalls) > 0)
|
||||||
|
}
|
||||||
|
return map[string]any{"role": "assistant", "content": ""}, "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiFinishReason(candidate map[string]any, hasToolCalls bool) string {
|
||||||
|
if hasToolCalls {
|
||||||
|
return "tool_calls"
|
||||||
|
}
|
||||||
|
switch strings.ToUpper(stringFromAny(candidate["finishReason"])) {
|
||||||
|
case "MAX_TOKENS":
|
||||||
|
return "length"
|
||||||
|
case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII":
|
||||||
|
return "content_filter"
|
||||||
|
default:
|
||||||
|
return "stop"
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiImageData(raw map[string]any) []any {
|
func geminiImageData(raw map[string]any) []any {
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -87,8 +88,11 @@ func decodeOpenAIStreamReader(reader io.Reader, onDelta StreamDelta) (map[string
|
|||||||
scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
|
scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
|
||||||
rawLines := make([]string, 0)
|
rawLines := make([]string, 0)
|
||||||
parts := make([]string, 0)
|
parts := make([]string, 0)
|
||||||
|
reasoningParts := make([]string, 0)
|
||||||
var last map[string]any
|
var last map[string]any
|
||||||
var usage Usage
|
var usage Usage
|
||||||
|
finishReason := ""
|
||||||
|
toolCalls := map[int]map[string]any{}
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
rawLine := scanner.Text()
|
rawLine := scanner.Text()
|
||||||
rawLines = append(rawLines, rawLine)
|
rawLines = append(rawLines, rawLine)
|
||||||
@ -104,13 +108,23 @@ func decodeOpenAIStreamReader(reader io.Reader, onDelta StreamDelta) (map[string
|
|||||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
event = NormalizeChatCompletionStreamEvent(event)
|
||||||
last = event
|
last = event
|
||||||
if text := streamEventText(event); text != "" {
|
text := streamEventText(event)
|
||||||
|
reasoningText := streamEventReasoningContent(event)
|
||||||
|
if text != "" {
|
||||||
parts = append(parts, text)
|
parts = append(parts, text)
|
||||||
if onDelta != nil {
|
|
||||||
if err := onDelta(text); err != nil {
|
|
||||||
return nil, true, err
|
|
||||||
}
|
}
|
||||||
|
if reasoningText != "" {
|
||||||
|
reasoningParts = append(reasoningParts, reasoningText)
|
||||||
|
}
|
||||||
|
aggregateStreamToolCalls(event, toolCalls)
|
||||||
|
if reason := streamEventFinishReason(event); reason != "" {
|
||||||
|
finishReason = reason
|
||||||
|
}
|
||||||
|
if onDelta != nil {
|
||||||
|
if err := onDelta(StreamDeltaEvent{Text: text, ReasoningContent: reasoningText, Event: event}); err != nil {
|
||||||
|
return nil, true, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if eventUsage := usageFromOpenAI(event); eventUsage.TotalTokens > 0 {
|
if eventUsage := usageFromOpenAI(event); eventUsage.TotalTokens > 0 {
|
||||||
@ -131,7 +145,7 @@ func decodeOpenAIStreamReader(reader io.Reader, onDelta StreamDelta) (map[string
|
|||||||
}
|
}
|
||||||
return out, true, nil
|
return out, true, nil
|
||||||
}
|
}
|
||||||
return buildOpenAIStreamResult(last, parts, usage), true, nil
|
return buildOpenAIStreamResult(last, parts, reasoningParts, toolCalls, finishReason, usage), true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeOpenAIStream(raw []byte) (map[string]any, bool) {
|
func decodeOpenAIStream(raw []byte) (map[string]any, bool) {
|
||||||
@ -142,10 +156,23 @@ func decodeOpenAIStream(raw []byte) (map[string]any, bool) {
|
|||||||
return result, ok && err == nil
|
return result, ok && err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildOpenAIStreamResult(last map[string]any, parts []string, usage Usage) map[string]any {
|
func buildOpenAIStreamResult(last map[string]any, parts []string, reasoningParts []string, toolCalls map[int]map[string]any, finishReason string, usage Usage) map[string]any {
|
||||||
if len(parts) == 0 {
|
if len(parts) == 0 && len(reasoningParts) == 0 && len(toolCalls) == 0 {
|
||||||
return last
|
return last
|
||||||
}
|
}
|
||||||
|
message := map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": strings.Join(parts, ""),
|
||||||
|
}
|
||||||
|
if len(reasoningParts) > 0 {
|
||||||
|
message["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||||
|
}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
message["tool_calls"] = sortedStreamToolCalls(toolCalls)
|
||||||
|
}
|
||||||
|
if finishReason == "" {
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
var out map[string]any
|
var out map[string]any
|
||||||
out = map[string]any{
|
out = map[string]any{
|
||||||
"id": stringFromAny(firstPresent(last["id"], "chatcmpl-stream")),
|
"id": stringFromAny(firstPresent(last["id"], "chatcmpl-stream")),
|
||||||
@ -153,11 +180,8 @@ func buildOpenAIStreamResult(last map[string]any, parts []string, usage Usage) m
|
|||||||
"model": stringFromAny(last["model"]),
|
"model": stringFromAny(last["model"]),
|
||||||
"choices": []any{map[string]any{
|
"choices": []any{map[string]any{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": map[string]any{
|
"message": message,
|
||||||
"role": "assistant",
|
"finish_reason": finishReason,
|
||||||
"content": strings.Join(parts, ""),
|
|
||||||
},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
if usage.TotalTokens > 0 {
|
if usage.TotalTokens > 0 {
|
||||||
@ -170,6 +194,571 @@ func buildOpenAIStreamResult(last map[string]any, parts []string, usage Usage) m
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeChatCompletionRequestBody 将后续请求里的工具调用上下文还原为
|
||||||
|
// OpenAI Chat Completions 标准格式,便于再次发送给 OpenAI-compatible 上游。
|
||||||
|
func NormalizeChatCompletionRequestBody(body map[string]any) map[string]any {
|
||||||
|
if body == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := cloneBody(body)
|
||||||
|
messages, ok := out["messages"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
normalizedMessages := make([]any, 0, len(messages))
|
||||||
|
for _, rawMessage := range messages {
|
||||||
|
message, ok := rawMessage.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
normalizedMessages = append(normalizedMessages, rawMessage)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copied := cloneMapAny(message)
|
||||||
|
normalizeToolCallsContainer(copied, false)
|
||||||
|
normalizeToolMessageFields(copied)
|
||||||
|
toolMessages, cleanContent, changed := toolResultMessagesFromContent(copied["content"])
|
||||||
|
if changed {
|
||||||
|
if cleanContent != nil && contentHasText(cleanContent) {
|
||||||
|
copied["content"] = cleanContent
|
||||||
|
normalizedMessages = append(normalizedMessages, copied)
|
||||||
|
} else if len(copied) > 1 || copied["role"] != nil {
|
||||||
|
delete(copied, "content")
|
||||||
|
if len(copied) > 1 {
|
||||||
|
normalizedMessages = append(normalizedMessages, copied)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalizedMessages = append(normalizedMessages, toolMessages...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalizedMessages = append(normalizedMessages, copied)
|
||||||
|
}
|
||||||
|
out["messages"] = normalizedMessages
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneMapAny(source map[string]any) map[string]any {
|
||||||
|
if source == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]any, len(source))
|
||||||
|
for key, value := range source {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeChatCompletionResult 将供应商自定义推理字段归一化到
|
||||||
|
// message.reasoning_content,并从最终回答 content 中剥离内联推理块。
|
||||||
|
// 加密推理载荷不可展示,且不应作为正文输出,因此会被忽略。
|
||||||
|
func NormalizeChatCompletionResult(result map[string]any) map[string]any {
|
||||||
|
if result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
choices, _ := result["choices"].([]any)
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
if message, ok := choice["message"].(map[string]any); ok {
|
||||||
|
normalizeToolCallsContainer(message, false)
|
||||||
|
normalizeReasoningContainer(message, false)
|
||||||
|
}
|
||||||
|
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||||
|
normalizeToolCallsContainer(delta, true)
|
||||||
|
normalizeReasoningContainer(delta, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeChatCompletionStreamEvent 将供应商自定义流式推理字段
|
||||||
|
// (例如 reasoning_details 或 reasoning)归一化到 delta.reasoning_content。
|
||||||
|
func NormalizeChatCompletionStreamEvent(event map[string]any) map[string]any {
|
||||||
|
if event == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
choices, _ := event["choices"].([]any)
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||||
|
normalizeToolCallsContainer(delta, true)
|
||||||
|
normalizeReasoningContainer(delta, true)
|
||||||
|
}
|
||||||
|
if message, ok := choice["message"].(map[string]any); ok {
|
||||||
|
normalizeToolCallsContainer(message, false)
|
||||||
|
normalizeReasoningContainer(message, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return event
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolCallsContainer(container map[string]any, stream bool) {
|
||||||
|
if container == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
toolCalls := make([]any, 0)
|
||||||
|
for _, rawToolCall := range rawToolCallValues(container) {
|
||||||
|
for _, normalized := range normalizeRawToolCalls(rawToolCall, len(toolCalls), stream) {
|
||||||
|
toolCalls = append(toolCalls, normalized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if contentToolCalls, cleanContent, changed := toolCallsFromContent(container["content"], len(toolCalls), stream); changed {
|
||||||
|
toolCalls = append(toolCalls, contentToolCalls...)
|
||||||
|
setNormalizedContent(container, cleanContent, stream)
|
||||||
|
}
|
||||||
|
if partToolCalls := toolCallsFromParts(container["parts"], len(toolCalls), stream); len(partToolCalls) > 0 {
|
||||||
|
toolCalls = append(toolCalls, partToolCalls...)
|
||||||
|
delete(container, "parts")
|
||||||
|
}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
container["tool_calls"] = toolCalls
|
||||||
|
}
|
||||||
|
for _, key := range []string{"tool_call", "toolCall", "toolCalls", "function_call", "functionCall"} {
|
||||||
|
delete(container, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolMessageFields(message map[string]any) {
|
||||||
|
if message == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if id := firstNonEmptyString(message["tool_call_id"], message["toolCallId"], message["tool_use_id"], message["toolUseId"], message["call_id"], message["callId"]); id != "" {
|
||||||
|
message["tool_call_id"] = id
|
||||||
|
}
|
||||||
|
for _, key := range []string{"toolCallId", "tool_use_id", "toolUseId", "call_id", "callId"} {
|
||||||
|
delete(message, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolResultMessagesFromContent(value any) ([]any, any, bool) {
|
||||||
|
blocks, ok := value.([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
toolMessages := make([]any, 0)
|
||||||
|
remaining := make([]any, 0, len(blocks))
|
||||||
|
for _, rawBlock := range blocks {
|
||||||
|
block, _ := rawBlock.(map[string]any)
|
||||||
|
if len(block) == 0 || stringFromAny(block["type"]) != "tool_result" {
|
||||||
|
remaining = append(remaining, rawBlock)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
message := map[string]any{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": firstNonEmptyString(block["tool_call_id"], block["toolCallId"], block["tool_use_id"], block["toolUseId"], block["id"]),
|
||||||
|
"content": toolResultContent(block["content"]),
|
||||||
|
}
|
||||||
|
toolMessages = append(toolMessages, message)
|
||||||
|
}
|
||||||
|
if len(toolMessages) == 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
return toolMessages, contentBlocksText(remaining), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolResultContent(value any) any {
|
||||||
|
if text, ok := value.(string); ok {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
return jsonStringFromAny(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentHasText(value any) bool {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(typed) != ""
|
||||||
|
case []any:
|
||||||
|
return len(typed) > 0
|
||||||
|
default:
|
||||||
|
return value != nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawToolCallValues(container map[string]any) []any {
|
||||||
|
values := make([]any, 0, 6)
|
||||||
|
for _, key := range []string{"tool_calls", "tool_call", "toolCalls", "toolCall", "function_call", "functionCall"} {
|
||||||
|
if value, ok := container[key]; ok {
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRawToolCalls(value any, startIndex int, stream bool) []any {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case []any:
|
||||||
|
out := make([]any, 0, len(typed))
|
||||||
|
for _, raw := range typed {
|
||||||
|
if toolCall := normalizeToolCall(raw, startIndex+len(out), stream); toolCall != nil {
|
||||||
|
out = append(out, toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
if toolCall := normalizeToolCall(value, startIndex, stream); toolCall != nil {
|
||||||
|
return []any{toolCall}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolCall(value any, index int, stream bool) map[string]any {
|
||||||
|
source, _ := value.(map[string]any)
|
||||||
|
if len(source) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
functionSource := mapFromAny(source["function"])
|
||||||
|
if len(functionSource) == 0 {
|
||||||
|
functionSource = mapFromAny(firstPresent(source["function_call"], source["functionCall"]))
|
||||||
|
}
|
||||||
|
name := firstNonEmptyString(
|
||||||
|
functionSource["name"], source["name"], source["function_name"], source["functionName"], source["tool_name"], source["toolName"],
|
||||||
|
)
|
||||||
|
arguments, hasArguments := toolCallArguments(functionSource)
|
||||||
|
if !hasArguments {
|
||||||
|
arguments, hasArguments = toolCallArguments(source)
|
||||||
|
}
|
||||||
|
if name == "" && !hasArguments && firstNonEmptyString(source["id"], source["call_id"], source["callId"], source["tool_call_id"], source["toolCallId"]) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
function := map[string]any{}
|
||||||
|
if name != "" {
|
||||||
|
function["name"] = name
|
||||||
|
}
|
||||||
|
if hasArguments {
|
||||||
|
function["arguments"] = arguments
|
||||||
|
}
|
||||||
|
toolCall := map[string]any{
|
||||||
|
"type": firstNonEmptyString(source["type"], "function"),
|
||||||
|
"function": function,
|
||||||
|
}
|
||||||
|
if id := firstNonEmptyString(source["id"], source["call_id"], source["callId"], source["tool_call_id"], source["toolCallId"]); id != "" {
|
||||||
|
toolCall["id"] = id
|
||||||
|
} else if !stream {
|
||||||
|
toolCall["id"] = fmt.Sprintf("call_%d", index)
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
if rawIndex, ok := firstPresent(source["index"], source["idx"]).(float64); ok {
|
||||||
|
toolCall["index"] = int(math.Round(rawIndex))
|
||||||
|
} else if rawIndex, ok := firstPresent(source["index"], source["idx"]).(int); ok {
|
||||||
|
toolCall["index"] = rawIndex
|
||||||
|
} else {
|
||||||
|
toolCall["index"] = index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return toolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallsFromContent(value any, startIndex int, stream bool) ([]any, any, bool) {
|
||||||
|
blocks, ok := value.([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
toolCalls := make([]any, 0)
|
||||||
|
remaining := make([]any, 0, len(blocks))
|
||||||
|
containsReasoning := false
|
||||||
|
for _, rawBlock := range blocks {
|
||||||
|
block, _ := rawBlock.(map[string]any)
|
||||||
|
if len(block) == 0 {
|
||||||
|
remaining = append(remaining, rawBlock)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch stringFromAny(block["type"]) {
|
||||||
|
case "tool_use":
|
||||||
|
if toolCall := normalizeToolUseBlock(block, startIndex+len(toolCalls), stream); toolCall != nil {
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
remaining = append(remaining, rawBlock)
|
||||||
|
default:
|
||||||
|
if isReasoningContentBlock(block) {
|
||||||
|
containsReasoning = true
|
||||||
|
}
|
||||||
|
remaining = append(remaining, rawBlock)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(toolCalls) == 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if containsReasoning {
|
||||||
|
return toolCalls, remaining, true
|
||||||
|
}
|
||||||
|
return toolCalls, contentBlocksText(remaining), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolUseBlock(block map[string]any, index int, stream bool) map[string]any {
|
||||||
|
toolCall := map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": stringFromAny(block["name"]),
|
||||||
|
"arguments": jsonStringFromAny(block["input"]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if id := firstNonEmptyString(block["id"], block["tool_use_id"], block["toolUseId"]); id != "" {
|
||||||
|
toolCall["id"] = id
|
||||||
|
} else if !stream {
|
||||||
|
toolCall["id"] = fmt.Sprintf("call_%d", index)
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
toolCall["index"] = index
|
||||||
|
}
|
||||||
|
return toolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallsFromParts(value any, startIndex int, stream bool) []any {
|
||||||
|
parts, ok := value.([]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]any, 0)
|
||||||
|
for _, rawPart := range parts {
|
||||||
|
part, _ := rawPart.(map[string]any)
|
||||||
|
if functionCall := mapFromAny(firstPresent(part["functionCall"], part["function_call"])); len(functionCall) > 0 {
|
||||||
|
if toolCall := normalizeGeminiFunctionCall(functionCall, startIndex+len(out), stream); toolCall != nil {
|
||||||
|
out = append(out, toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGeminiFunctionCall(functionCall map[string]any, index int, stream bool) map[string]any {
|
||||||
|
toolCall := map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": stringFromAny(functionCall["name"]),
|
||||||
|
"arguments": jsonStringFromAny(firstPresent(functionCall["args"], functionCall["arguments"])),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if id := firstNonEmptyString(functionCall["id"], functionCall["call_id"], functionCall["callId"]); id != "" {
|
||||||
|
toolCall["id"] = id
|
||||||
|
} else if !stream {
|
||||||
|
toolCall["id"] = fmt.Sprintf("call_%d", index)
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
toolCall["index"] = index
|
||||||
|
}
|
||||||
|
return toolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func setNormalizedContent(container map[string]any, value any, stream bool) {
|
||||||
|
if text, ok := value.(string); ok && text == "" {
|
||||||
|
if stream {
|
||||||
|
delete(container, "content")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
container["content"] = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
container["content"] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReasoningContentBlock(block map[string]any) bool {
|
||||||
|
switch stringFromAny(block["type"]) {
|
||||||
|
case "thinking", "redacted_thinking", "reasoning.text", "reasoning.summary", "reasoning.encrypted":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contentBlocksText(blocks []any) string {
|
||||||
|
parts := make([]string, 0, len(blocks))
|
||||||
|
for _, rawBlock := range blocks {
|
||||||
|
switch block := rawBlock.(type) {
|
||||||
|
case string:
|
||||||
|
parts = append(parts, block)
|
||||||
|
case map[string]any:
|
||||||
|
if text := stringFromAny(firstPresent(block["text"], block["content"])); text != "" {
|
||||||
|
parts = append(parts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallArguments(source map[string]any) (string, bool) {
|
||||||
|
for _, key := range []string{"arguments", "args", "input", "parameters"} {
|
||||||
|
if value, ok := source[key]; ok {
|
||||||
|
return jsonStringFromAny(value), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func jsonStringFromAny(value any) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if text, ok := value.(string); ok {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
encoded, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeReasoningContainer(container map[string]any, deleteEmptyContent bool) {
|
||||||
|
if container == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reasoningParts := make([]string, 0, 3)
|
||||||
|
if reasoning := reasoningDetailsText(container["reasoning_details"]); reasoning != "" {
|
||||||
|
reasoningParts = append(reasoningParts, reasoning)
|
||||||
|
} else if reasoning := stringFromAny(container["reasoning_content"]); reasoning != "" {
|
||||||
|
reasoningParts = append(reasoningParts, reasoning)
|
||||||
|
} else if reasoning := stringFromAny(container["reasoning"]); reasoning != "" {
|
||||||
|
reasoningParts = append(reasoningParts, reasoning)
|
||||||
|
}
|
||||||
|
if content, ok := container["content"]; ok {
|
||||||
|
cleanContent, contentReasoning, changed := normalizeReasoningContentValue(content)
|
||||||
|
if changed {
|
||||||
|
if deleteEmptyContent {
|
||||||
|
if text, ok := cleanContent.(string); ok && text == "" {
|
||||||
|
delete(container, "content")
|
||||||
|
} else {
|
||||||
|
container["content"] = cleanContent
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
container["content"] = cleanContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if contentReasoning != "" {
|
||||||
|
reasoningParts = append(reasoningParts, contentReasoning)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(reasoningParts) > 0 {
|
||||||
|
container["reasoning_content"] = strings.Join(reasoningParts, "")
|
||||||
|
}
|
||||||
|
delete(container, "reasoning_details")
|
||||||
|
delete(container, "reasoning")
|
||||||
|
}
|
||||||
|
|
||||||
|
func reasoningDetailsText(value any) string {
|
||||||
|
rawItems, ok := value.([]any)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts := make([]string, 0, len(rawItems))
|
||||||
|
for _, rawItem := range rawItems {
|
||||||
|
item, _ := rawItem.(map[string]any)
|
||||||
|
switch stringFromAny(item["type"]) {
|
||||||
|
case "reasoning.text":
|
||||||
|
if text := stringFromAny(item["text"]); text != "" {
|
||||||
|
parts = append(parts, text)
|
||||||
|
}
|
||||||
|
case "reasoning.summary":
|
||||||
|
if summary := stringFromAny(item["summary"]); summary != "" {
|
||||||
|
parts = append(parts, summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeReasoningContentValue(value any) (any, string, bool) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
cleanContent, reasoning, changed := splitTaggedReasoningText(typed)
|
||||||
|
return cleanContent, reasoning, changed
|
||||||
|
case []any:
|
||||||
|
contentParts := make([]string, 0, len(typed))
|
||||||
|
reasoningParts := make([]string, 0)
|
||||||
|
changed := false
|
||||||
|
for _, rawItem := range typed {
|
||||||
|
switch item := rawItem.(type) {
|
||||||
|
case string:
|
||||||
|
contentParts = append(contentParts, item)
|
||||||
|
case map[string]any:
|
||||||
|
switch stringFromAny(item["type"]) {
|
||||||
|
case "thinking":
|
||||||
|
if thinking := stringFromAny(item["thinking"]); thinking != "" {
|
||||||
|
reasoningParts = append(reasoningParts, thinking)
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
case "redacted_thinking", "reasoning.encrypted":
|
||||||
|
changed = true
|
||||||
|
case "reasoning.text":
|
||||||
|
if text := stringFromAny(item["text"]); text != "" {
|
||||||
|
reasoningParts = append(reasoningParts, text)
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
case "reasoning.summary":
|
||||||
|
if summary := stringFromAny(item["summary"]); summary != "" {
|
||||||
|
reasoningParts = append(reasoningParts, summary)
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
case "text", "output_text":
|
||||||
|
if text := stringFromAny(firstPresent(item["text"], item["content"])); text != "" {
|
||||||
|
cleanText, reasoning, tagged := splitTaggedReasoningText(text)
|
||||||
|
contentParts = append(contentParts, cleanText)
|
||||||
|
if reasoning != "" {
|
||||||
|
reasoningParts = append(reasoningParts, reasoning)
|
||||||
|
}
|
||||||
|
changed = changed || tagged
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text := stringFromAny(firstPresent(item["text"], item["content"])); text != "" {
|
||||||
|
contentParts = append(contentParts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return value, "", false
|
||||||
|
}
|
||||||
|
return strings.Join(contentParts, ""), strings.Join(reasoningParts, ""), true
|
||||||
|
default:
|
||||||
|
return value, "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitTaggedReasoningText(text string) (string, string, bool) {
|
||||||
|
lower := strings.ToLower(text)
|
||||||
|
clean := strings.Builder{}
|
||||||
|
reasoning := strings.Builder{}
|
||||||
|
changed := false
|
||||||
|
for offset := 0; offset < len(text); {
|
||||||
|
start, tag := nextReasoningOpenTag(lower, offset)
|
||||||
|
if start < 0 {
|
||||||
|
clean.WriteString(text[offset:])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
clean.WriteString(text[offset:start])
|
||||||
|
openEnd := start + len("<"+tag+">")
|
||||||
|
closeToken := "</" + tag + ">"
|
||||||
|
closeStart := strings.Index(lower[openEnd:], closeToken)
|
||||||
|
if closeStart < 0 {
|
||||||
|
reasoning.WriteString(text[openEnd:])
|
||||||
|
offset = len(text)
|
||||||
|
changed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
closeStart += openEnd
|
||||||
|
reasoning.WriteString(text[openEnd:closeStart])
|
||||||
|
offset = closeStart + len(closeToken)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
return clean.String(), reasoning.String(), changed
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextReasoningOpenTag(lower string, offset int) (int, string) {
|
||||||
|
bestStart := -1
|
||||||
|
bestTag := ""
|
||||||
|
for _, tag := range []string{"think", "reasoning", "analysis"} {
|
||||||
|
needle := "<" + tag + ">"
|
||||||
|
idx := strings.Index(lower[offset:], needle)
|
||||||
|
if idx < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
absolute := offset + idx
|
||||||
|
if bestStart < 0 || absolute < bestStart {
|
||||||
|
bestStart = absolute
|
||||||
|
bestTag = tag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestStart, bestTag
|
||||||
|
}
|
||||||
|
|
||||||
func streamEventText(event map[string]any) string {
|
func streamEventText(event map[string]any) string {
|
||||||
if choices, ok := event["choices"].([]any); ok {
|
if choices, ok := event["choices"].([]any); ok {
|
||||||
for _, rawChoice := range choices {
|
for _, rawChoice := range choices {
|
||||||
@ -195,6 +784,91 @@ func streamEventText(event map[string]any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func streamEventReasoningContent(event map[string]any) string {
|
||||||
|
if choices, ok := event["choices"].([]any); ok {
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||||
|
if content, ok := delta["reasoning_content"].(string); ok {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
if content, ok := delta["reasoning"].(string); ok {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if message, ok := choice["message"].(map[string]any); ok {
|
||||||
|
if content, ok := message["reasoning_content"].(string); ok {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamEventFinishReason(event map[string]any) string {
|
||||||
|
if choices, ok := event["choices"].([]any); ok {
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
if reason, ok := choice["finish_reason"].(string); ok && reason != "" {
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func aggregateStreamToolCalls(event map[string]any, toolCalls map[int]map[string]any) {
|
||||||
|
choices, _ := event["choices"].([]any)
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
rawToolCalls, _ := delta["tool_calls"].([]any)
|
||||||
|
for _, rawToolCall := range rawToolCalls {
|
||||||
|
incoming, _ := rawToolCall.(map[string]any)
|
||||||
|
index := intFromAny(incoming["index"])
|
||||||
|
current := toolCalls[index]
|
||||||
|
if current == nil {
|
||||||
|
current = map[string]any{}
|
||||||
|
toolCalls[index] = current
|
||||||
|
}
|
||||||
|
for _, key := range []string{"id", "type"} {
|
||||||
|
if value, ok := incoming[key].(string); ok && value != "" {
|
||||||
|
current[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
incomingFn, _ := incoming["function"].(map[string]any)
|
||||||
|
if len(incomingFn) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentFn, _ := current["function"].(map[string]any)
|
||||||
|
if currentFn == nil {
|
||||||
|
currentFn = map[string]any{}
|
||||||
|
current["function"] = currentFn
|
||||||
|
}
|
||||||
|
if name, ok := incomingFn["name"].(string); ok && name != "" {
|
||||||
|
currentFn["name"] = stringFromAny(currentFn["name"]) + name
|
||||||
|
}
|
||||||
|
if arguments, ok := incomingFn["arguments"].(string); ok && arguments != "" {
|
||||||
|
currentFn["arguments"] = stringFromAny(currentFn["arguments"]) + arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortedStreamToolCalls(toolCalls map[int]map[string]any) []any {
|
||||||
|
indices := make([]int, 0, len(toolCalls))
|
||||||
|
for index := range toolCalls {
|
||||||
|
indices = append(indices, index)
|
||||||
|
}
|
||||||
|
sort.Ints(indices)
|
||||||
|
out := make([]any, 0, len(indices))
|
||||||
|
for _, index := range indices {
|
||||||
|
out = append(out, toolCalls[index])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func usageFromOpenAI(result map[string]any) Usage {
|
func usageFromOpenAI(result map[string]any) Usage {
|
||||||
usage, _ := result["usage"].(map[string]any)
|
usage, _ := result["usage"].(map[string]any)
|
||||||
input := intFromAny(firstPresent(usage["prompt_tokens"], usage["input_tokens"]))
|
input := intFromAny(firstPresent(usage["prompt_tokens"], usage["input_tokens"]))
|
||||||
@ -254,6 +928,15 @@ func stringFromAny(value any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func firstNonEmptyString(values ...any) string {
|
||||||
|
for _, value := range values {
|
||||||
|
if text := strings.TrimSpace(stringFromAny(value)); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func firstPresent(values ...any) any {
|
func firstPresent(values ...any) any {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
if value != nil {
|
if value != nil {
|
||||||
|
|||||||
960
apps/api/internal/clients/keling.go
Normal file
960
apps/api/internal/clients/keling.go
Normal file
@ -0,0 +1,960 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KelingClient struct {
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type kelingPreparedTask struct {
|
||||||
|
Endpoint string
|
||||||
|
Payload map[string]any
|
||||||
|
RemoteTaskPayload map[string]any
|
||||||
|
CleanupElementIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
if request.Kind != "videos.generations" {
|
||||||
|
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported keling request kind", Retryable: false}
|
||||||
|
}
|
||||||
|
token, err := kelingAuthToken(request.Candidate)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
return c.runVideo(ctx, request, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) runVideo(ctx context.Context, request Request, token string) (Response, error) {
|
||||||
|
submitStartedAt := time.Now()
|
||||||
|
submitRequestID := strings.TrimSpace(request.RemoteTaskID)
|
||||||
|
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
||||||
|
prepared := kelingResumePreparedTask(request)
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
var err error
|
||||||
|
prepared, err = c.prepareVideoTask(ctx, request, token)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
submitResult, requestID, err := c.postJSON(ctx, request, prepared.Endpoint, token, prepared.Payload)
|
||||||
|
submitRequestID = requestID
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, time.Now())
|
||||||
|
}
|
||||||
|
upstreamTaskID = strings.TrimSpace(stringFromAny(kelingData(submitResult)["task_id"]))
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
|
||||||
|
return Response{}, &ClientError{Code: "invalid_response", Message: "keling video task id is missing", RequestID: submitRequestID, Retryable: false}
|
||||||
|
}
|
||||||
|
prepared.RemoteTaskPayload["submit"] = submitResult
|
||||||
|
if request.OnRemoteTaskSubmitted != nil {
|
||||||
|
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, prepared.RemoteTaskPayload); err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pollEndpoint := kelingPollEndpoint(request, prepared.Endpoint)
|
||||||
|
interval := kelingPollInterval(request)
|
||||||
|
timeout := kelingPollTimeout(request)
|
||||||
|
deadline := time.NewTimer(timeout)
|
||||||
|
defer deadline.Stop()
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastResult map[string]any
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: submitRequestID, Retryable: true}
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
pollStartedAt := time.Now()
|
||||||
|
pollResult, pollRequestID, err := c.getJSON(ctx, request, pollEndpoint+"/"+upstreamTaskID, token)
|
||||||
|
pollFinishedAt := time.Now()
|
||||||
|
requestID := firstNonEmpty(pollRequestID, submitRequestID, upstreamTaskID)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, annotateResponseError(err, requestID, pollStartedAt, pollFinishedAt)
|
||||||
|
}
|
||||||
|
lastResult = pollResult
|
||||||
|
|
||||||
|
switch kelingTaskStatus(pollResult) {
|
||||||
|
case "succeed":
|
||||||
|
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
|
||||||
|
prepared.CleanupElementIDs = nil
|
||||||
|
result := kelingVideoSuccessResult(request, upstreamTaskID, pollResult)
|
||||||
|
return Response{
|
||||||
|
Result: result,
|
||||||
|
RequestID: requestID,
|
||||||
|
Progress: kelingVideoProgress(request, upstreamTaskID),
|
||||||
|
ResponseStartedAt: submitStartedAt,
|
||||||
|
ResponseFinishedAt: pollFinishedAt,
|
||||||
|
ResponseDurationMS: responseDurationMS(submitStartedAt, pollFinishedAt),
|
||||||
|
}, nil
|
||||||
|
case "failed":
|
||||||
|
_ = c.cleanupKelingElements(context.WithoutCancel(ctx), request, token, prepared.CleanupElementIDs)
|
||||||
|
prepared.CleanupElementIDs = nil
|
||||||
|
return Response{}, &ClientError{
|
||||||
|
Code: kelingTaskErrorCode(pollResult),
|
||||||
|
Message: kelingTaskErrorMessage(request.Candidate, pollResult),
|
||||||
|
RequestID: requestID,
|
||||||
|
ResponseStartedAt: submitStartedAt,
|
||||||
|
ResponseFinishedAt: pollFinishedAt,
|
||||||
|
ResponseDurationMS: responseDurationMS(submitStartedAt, pollFinishedAt),
|
||||||
|
Retryable: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
||||||
|
case <-deadline.C:
|
||||||
|
return Response{}, &ClientError{
|
||||||
|
Code: "timeout",
|
||||||
|
Message: fmt.Sprintf("keling video task %s did not finish before timeout; last status: %s", upstreamTaskID, kelingTaskStatus(lastResult)),
|
||||||
|
RequestID: requestID,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) prepareVideoTask(ctx context.Context, request Request, token string) (kelingPreparedTask, error) {
|
||||||
|
if kelingIsOmniRequest(request) {
|
||||||
|
payload, cleanupIDs, err := c.kelingOmniPayload(ctx, request, token)
|
||||||
|
if err != nil {
|
||||||
|
return kelingPreparedTask{}, err
|
||||||
|
}
|
||||||
|
return kelingPreparedTask{
|
||||||
|
Endpoint: "/videos/omni-video",
|
||||||
|
Payload: payload,
|
||||||
|
RemoteTaskPayload: map[string]any{"endpoint": "/videos/omni-video", "mode": "omni_video", "cleanupElementIds": cleanupIDs},
|
||||||
|
CleanupElementIDs: cleanupIDs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
payload, taskType, err := kelingVideoPayload(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
return kelingPreparedTask{}, err
|
||||||
|
}
|
||||||
|
endpoint := "/videos/" + taskType
|
||||||
|
return kelingPreparedTask{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Payload: payload,
|
||||||
|
RemoteTaskPayload: map[string]any{"endpoint": endpoint, "taskType": taskType},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingResumePreparedTask(request Request) kelingPreparedTask {
|
||||||
|
endpoint := ""
|
||||||
|
for _, key := range []string{"endpoint", "pollEndpoint"} {
|
||||||
|
if value := strings.TrimSpace(stringFromAny(request.RemoteTaskPayload[key])); value != "" {
|
||||||
|
endpoint = value
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if endpoint == "" {
|
||||||
|
if kelingIsOmniRequest(request) {
|
||||||
|
endpoint = "/videos/omni-video"
|
||||||
|
} else {
|
||||||
|
endpoint = "/videos/" + kelingTaskTypeFromRequest(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kelingPreparedTask{Endpoint: endpoint, RemoteTaskPayload: map[string]any{"endpoint": endpoint}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingVideoPayload(ctx context.Context, request Request) (map[string]any, string, error) {
|
||||||
|
body := cleanProviderBody(request.Body)
|
||||||
|
content := contentItems(body["content"])
|
||||||
|
if len(content) == 0 {
|
||||||
|
content = buildVolcesContentFromBody(body)
|
||||||
|
}
|
||||||
|
prompt := firstKelingPrompt(content)
|
||||||
|
if prompt == "" {
|
||||||
|
return nil, "", &ClientError{Code: "invalid_parameter", Message: "keling video prompt is required", StatusCode: 400, Retryable: false}
|
||||||
|
}
|
||||||
|
firstFrame, lastFrame, referenceImages := kelingImageInputs(content)
|
||||||
|
isImage2Video := firstFrame != "" || lastFrame != "" || len(referenceImages) > 0
|
||||||
|
primaryImage := firstFrame
|
||||||
|
if primaryImage == "" && len(referenceImages) <= 1 && len(referenceImages) > 0 {
|
||||||
|
primaryImage = referenceImages[0]
|
||||||
|
}
|
||||||
|
if primaryImage == "" {
|
||||||
|
primaryImage = lastFrame
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"prompt": prompt,
|
||||||
|
"model_name": upstreamModelName(request.Candidate),
|
||||||
|
"duration": fmtDuration(body["duration"], 5),
|
||||||
|
}
|
||||||
|
if value := strings.TrimSpace(stringFromAny(body["negative_prompt"])); value != "" {
|
||||||
|
payload["negative_prompt"] = value
|
||||||
|
}
|
||||||
|
if value, ok := body["cfg_scale"]; ok && numericValue(value, 0) > 0 {
|
||||||
|
payload["cfg_scale"] = value
|
||||||
|
}
|
||||||
|
if boolValue(body, "audio") || boolValue(body, "output_audio") {
|
||||||
|
payload["sound"] = "on"
|
||||||
|
}
|
||||||
|
if mode := kelingModeByResolution(firstNonEmptyStringValue(body, "resolution", "size")); mode != "" {
|
||||||
|
payload["mode"] = mode
|
||||||
|
}
|
||||||
|
if ratio := strings.TrimSpace(firstNonEmptyStringValue(body, "aspect_ratio", "aspectRatio", "ratio")); strings.Contains(ratio, ":") {
|
||||||
|
payload["aspect_ratio"] = ratio
|
||||||
|
}
|
||||||
|
if camera := kelingCameraControl(body); camera != nil {
|
||||||
|
payload["camera_control"] = camera
|
||||||
|
}
|
||||||
|
if primaryImage != "" {
|
||||||
|
encoded, err := kelingImageToBase64(ctx, request, primaryImage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
payload["image"] = encoded
|
||||||
|
}
|
||||||
|
if lastFrame != "" {
|
||||||
|
encoded, err := kelingImageToBase64(ctx, request, lastFrame)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
payload["image_tail"] = encoded
|
||||||
|
}
|
||||||
|
if len(referenceImages) > 0 {
|
||||||
|
imageList := make([]any, 0, len(referenceImages))
|
||||||
|
for _, url := range referenceImages {
|
||||||
|
encoded, err := kelingImageToBase64(ctx, request, url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
imageList = append(imageList, map[string]any{"image": encoded})
|
||||||
|
}
|
||||||
|
payload["image_list"] = imageList
|
||||||
|
}
|
||||||
|
if !strings.Contains(stringFromAny(payload["aspect_ratio"]), ":") || isImage2Video {
|
||||||
|
delete(payload, "aspect_ratio")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskType := "text2video"
|
||||||
|
if primaryImage != "" {
|
||||||
|
taskType = "image2video"
|
||||||
|
} else if len(referenceImages) > 1 {
|
||||||
|
taskType = "multi-image2video"
|
||||||
|
}
|
||||||
|
return payload, taskType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingTaskTypeFromRequest(request Request) string {
|
||||||
|
body := cleanProviderBody(request.Body)
|
||||||
|
content := contentItems(body["content"])
|
||||||
|
if len(content) == 0 {
|
||||||
|
content = buildVolcesContentFromBody(body)
|
||||||
|
}
|
||||||
|
firstFrame, lastFrame, referenceImages := kelingImageInputs(content)
|
||||||
|
if firstFrame != "" || lastFrame != "" || len(referenceImages) == 1 {
|
||||||
|
return "image2video"
|
||||||
|
}
|
||||||
|
if len(referenceImages) > 1 {
|
||||||
|
return "multi-image2video"
|
||||||
|
}
|
||||||
|
return "text2video"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) kelingOmniPayload(ctx context.Context, request Request, token string) (map[string]any, []string, error) {
|
||||||
|
body := cleanProviderBody(request.Body)
|
||||||
|
content := contentItems(body["content"])
|
||||||
|
if len(content) == 0 {
|
||||||
|
content = buildVolcesContentFromBody(body)
|
||||||
|
}
|
||||||
|
prompt := firstKelingPrompt(content)
|
||||||
|
images := kelingOmniImageList(content)
|
||||||
|
videos := kelingOmniVideoList(content)
|
||||||
|
uploadedElementIDs := make([]string, 0)
|
||||||
|
elements, createdIDs, err := c.kelingOmniElementList(ctx, request, token, content)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
uploadedElementIDs = append(uploadedElementIDs, createdIDs...)
|
||||||
|
shots := kelingShotPrompts(content)
|
||||||
|
hasMultiPrompt := len(shots) > 0
|
||||||
|
hasVideo := len(videos) > 0
|
||||||
|
hasVideoEdit := kelingHasBaseVideo(videos)
|
||||||
|
hasFirstFrame := kelingHasFirstFrame(images)
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"model_name": upstreamModelName(request.Candidate),
|
||||||
|
"mode": kelingModeByResolution(firstNonEmptyStringValue(body, "resolution", "size")),
|
||||||
|
"watermark_info": map[string]any{"enabled": false},
|
||||||
|
"negative_prompt": strings.TrimSpace(stringFromAny(body["negative_prompt"])),
|
||||||
|
}
|
||||||
|
if !hasMultiPrompt {
|
||||||
|
payload["prompt"] = prompt
|
||||||
|
if body["duration"] != nil {
|
||||||
|
payload["duration"] = fmtDuration(body["duration"], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ratio := strings.TrimSpace(firstNonEmptyStringValue(body, "aspect_ratio", "aspectRatio", "ratio")); strings.Contains(ratio, ":") {
|
||||||
|
payload["aspect_ratio"] = ratio
|
||||||
|
}
|
||||||
|
if len(images) > 0 {
|
||||||
|
payload["image_list"] = images
|
||||||
|
}
|
||||||
|
if len(videos) > 0 {
|
||||||
|
payload["video_list"] = videos
|
||||||
|
}
|
||||||
|
if len(elements) > 0 {
|
||||||
|
payload["element_list"] = elements
|
||||||
|
}
|
||||||
|
if (boolValue(body, "audio") || boolValue(body, "output_audio")) && !hasVideo {
|
||||||
|
payload["sound"] = "on"
|
||||||
|
}
|
||||||
|
if hasMultiPrompt {
|
||||||
|
payload["multi_shot"] = true
|
||||||
|
payload["shot_type"] = "customize"
|
||||||
|
total := 0.0
|
||||||
|
multiPrompt := make([]any, 0, len(shots))
|
||||||
|
for index, shot := range shots {
|
||||||
|
duration := shot.duration
|
||||||
|
if duration <= 0 {
|
||||||
|
duration = 5
|
||||||
|
}
|
||||||
|
total += duration
|
||||||
|
multiPrompt = append(multiPrompt, map[string]any{
|
||||||
|
"index": index + 1,
|
||||||
|
"prompt": shot.text,
|
||||||
|
"duration": fmtDuration(duration, 5),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
delete(payload, "prompt")
|
||||||
|
payload["multi_prompt"] = multiPrompt
|
||||||
|
payload["duration"] = fmtDuration(total, 0)
|
||||||
|
}
|
||||||
|
deleteEmptyStringFields(payload)
|
||||||
|
if hasVideoEdit {
|
||||||
|
delete(payload, "duration")
|
||||||
|
delete(payload, "aspect_ratio")
|
||||||
|
}
|
||||||
|
if hasVideo && !hasVideoEdit && !strings.Contains(stringFromAny(payload["aspect_ratio"]), ":") {
|
||||||
|
payload["aspect_ratio"] = "16:9"
|
||||||
|
}
|
||||||
|
if !hasVideoEdit && !hasFirstFrame && !strings.Contains(stringFromAny(payload["aspect_ratio"]), ":") {
|
||||||
|
payload["aspect_ratio"] = "16:9"
|
||||||
|
}
|
||||||
|
return payload, uploadedElementIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) kelingOmniElementList(ctx context.Context, request Request, token string, content []map[string]any) ([]any, []string, error) {
|
||||||
|
elements := make([]any, 0)
|
||||||
|
createdIDs := make([]string, 0)
|
||||||
|
for _, item := range content {
|
||||||
|
if stringFromAny(item["type"]) != "element" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
element := mapFromAny(item["element"])
|
||||||
|
if element == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id := kelingStringFromAny(firstPresent(element["element_id"], element["id"])); id != "" {
|
||||||
|
elements = append(elements, map[string]any{"element_id": id})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
inline := mapFromAny(element["inline_element"])
|
||||||
|
if inline == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := kelingCreateElementPayload(inline)
|
||||||
|
if payload == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, err := c.createKelingElement(ctx, request, token, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, createdIDs, err
|
||||||
|
}
|
||||||
|
elements = append(elements, map[string]any{"element_id": id})
|
||||||
|
createdIDs = append(createdIDs, id)
|
||||||
|
}
|
||||||
|
return elements, createdIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) postJSON(ctx context.Context, request Request, path string, token string, body map[string]any) (map[string]any, string, error) {
|
||||||
|
raw, _ := json.Marshal(body)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, path), bytes.NewReader(raw))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
|
result, err := decodeHTTPResponse(resp)
|
||||||
|
if err != nil {
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
if code := intFromAny(result["code"]); code != 0 {
|
||||||
|
return result, requestID, &ClientError{Code: kelingEnvelopeErrorCode(result), Message: kelingEnvelopeErrorMessage(result), RequestID: firstNonEmpty(requestID, stringFromAny(result["request_id"])), Retryable: false}
|
||||||
|
}
|
||||||
|
return result, firstNonEmpty(requestID, stringFromAny(result["request_id"])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) getJSON(ctx context.Context, request Request, path string, token string) (map[string]any, string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, joinURL(request.Candidate.BaseURL, path), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
|
result, err := decodeHTTPResponse(resp)
|
||||||
|
if err != nil {
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
if code := intFromAny(result["code"]); code != 0 {
|
||||||
|
return result, requestID, &ClientError{Code: kelingEnvelopeErrorCode(result), Message: kelingEnvelopeErrorMessage(result), RequestID: firstNonEmpty(requestID, stringFromAny(result["request_id"])), Retryable: false}
|
||||||
|
}
|
||||||
|
return result, firstNonEmpty(requestID, stringFromAny(result["request_id"])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) createKelingElement(ctx context.Context, request Request, token string, payload map[string]any) (string, error) {
|
||||||
|
raw, _ := json.Marshal(payload)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, "/general/custom-elements"), bytes.NewReader(raw))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return "", &ClientError{Code: statusCodeName(resp.StatusCode), Message: errorMessage(body, resp.Status), StatusCode: resp.StatusCode, RequestID: requestIDFromHTTPResponse(resp), Retryable: HTTPRetryable(resp.StatusCode)}
|
||||||
|
}
|
||||||
|
var parsed struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Data map[string]any `json:"data"`
|
||||||
|
}
|
||||||
|
decoder := json.NewDecoder(bytes.NewReader(body))
|
||||||
|
decoder.UseNumber()
|
||||||
|
if err := decoder.Decode(&parsed); err != nil {
|
||||||
|
return "", &ClientError{Code: "invalid_response", Message: err.Error(), Retryable: false}
|
||||||
|
}
|
||||||
|
if parsed.Code != 0 {
|
||||||
|
return "", &ClientError{Code: "keling_element_create_failed", Message: parsed.Message, RequestID: parsed.RequestID, Retryable: false}
|
||||||
|
}
|
||||||
|
id := kelingStringFromAny(parsed.Data["element_id"])
|
||||||
|
if id == "" {
|
||||||
|
return "", &ClientError{Code: "invalid_response", Message: "keling element id is missing", RequestID: parsed.RequestID, Retryable: false}
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c KelingClient) cleanupKelingElements(ctx context.Context, request Request, token string, elementIDs []string) error {
|
||||||
|
for _, id := range elementIDs {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, _, _ = c.postJSON(ctx, request, "/general/delete-elements", token, map[string]any{"element_id": id})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingAuthToken(candidate store.RuntimeModelCandidate) (string, error) {
|
||||||
|
apiKey := credential(candidate.Credentials, "apiKey", "api_key", "key", "token")
|
||||||
|
accessKey := credential(candidate.Credentials, "accessKey", "access_key", "ak")
|
||||||
|
secretKey := credential(candidate.Credentials, "secretKey", "secret_key", "sk")
|
||||||
|
if accessKey != "" || secretKey != "" || strings.EqualFold(strings.TrimSpace(candidate.AuthType), "AccessKey-SecretKey") {
|
||||||
|
if accessKey == "" || secretKey == "" {
|
||||||
|
return "", &ClientError{Code: "missing_credentials", Message: "keling accessKey and secretKey are required", Retryable: false}
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": accessKey,
|
||||||
|
"exp": now.Add(30 * time.Minute).Unix(),
|
||||||
|
"nbf": now.Add(-5 * time.Second).Unix(),
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
signed, err := token.SignedString([]byte(secretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", &ClientError{Code: "auth_failed", Message: err.Error(), Retryable: false}
|
||||||
|
}
|
||||||
|
return signed, nil
|
||||||
|
}
|
||||||
|
if apiKey == "" {
|
||||||
|
return "", &ClientError{Code: "missing_credentials", Message: "keling api key is required", Retryable: false}
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingImageToBase64(ctx context.Context, request Request, value string) (string, error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(value, "data:") {
|
||||||
|
parts := strings.SplitN(value, ",", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
return strings.TrimSpace(parts[1]), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://") {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, value, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resp, err := httpClient(request.HTTPClient).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||||
|
return "", &ClientError{Code: statusCodeName(resp.StatusCode), Message: errorMessage(raw, resp.Status), StatusCode: resp.StatusCode, RequestID: requestIDFromHTTPResponse(resp), Retryable: HTTPRetryable(resp.StatusCode)}
|
||||||
|
}
|
||||||
|
raw, err := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
||||||
|
if err != nil {
|
||||||
|
return "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(raw), nil
|
||||||
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingIsOmniRequest(request Request) bool {
|
||||||
|
modelType := strings.TrimSpace(request.ModelType)
|
||||||
|
return modelType == "omni_video" || modelType == "omni" ||
|
||||||
|
request.Candidate.Capabilities["omni_video"] != nil ||
|
||||||
|
request.Candidate.Capabilities["omni"] != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstKelingPrompt(content []map[string]any) string {
|
||||||
|
for _, item := range content {
|
||||||
|
if stringFromAny(item["type"]) == "text" && stringFromAny(item["role"]) != "shot_prompt" && item["shot_index"] == nil {
|
||||||
|
if text := strings.TrimSpace(stringFromAny(item["text"])); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingImageInputs(content []map[string]any) (string, string, []string) {
|
||||||
|
firstFrame := ""
|
||||||
|
lastFrame := ""
|
||||||
|
references := make([]string, 0)
|
||||||
|
for _, item := range content {
|
||||||
|
if !isKelingImageContent(item) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
url := kelingNestedURL(item, "image_url")
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch stringFromAny(item["role"]) {
|
||||||
|
case "first_frame":
|
||||||
|
if firstFrame == "" {
|
||||||
|
firstFrame = url
|
||||||
|
}
|
||||||
|
case "last_frame":
|
||||||
|
if lastFrame == "" {
|
||||||
|
lastFrame = url
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
references = append(references, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return firstFrame, lastFrame, references
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingOmniImageList(content []map[string]any) []any {
|
||||||
|
out := make([]any, 0)
|
||||||
|
for _, item := range content {
|
||||||
|
if !isKelingImageContent(item) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
url := kelingNestedURL(item, "image_url")
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
image := map[string]any{"image_url": url}
|
||||||
|
switch stringFromAny(item["role"]) {
|
||||||
|
case "first_frame":
|
||||||
|
image["type"] = "first_frame"
|
||||||
|
case "last_frame":
|
||||||
|
image["type"] = "end_frame"
|
||||||
|
}
|
||||||
|
out = append(out, image)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingOmniVideoList(content []map[string]any) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0)
|
||||||
|
for _, item := range content {
|
||||||
|
if !isKelingVideoContent(item) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nested := mapFromAny(item["video_url"])
|
||||||
|
url := strings.TrimSpace(stringFromAny(nested["url"]))
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
video := map[string]any{"video_url": url}
|
||||||
|
referType := strings.TrimSpace(stringFromAny(nested["refer_type"]))
|
||||||
|
if referType == "" {
|
||||||
|
switch stringFromAny(item["role"]) {
|
||||||
|
case "video_base":
|
||||||
|
referType = "base"
|
||||||
|
case "video_feature", "reference_video":
|
||||||
|
referType = "feature"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if referType == "base" || referType == "feature" {
|
||||||
|
video["refer_type"] = referType
|
||||||
|
}
|
||||||
|
if keep := strings.TrimSpace(stringFromAny(nested["keep_original_sound"])); keep != "" {
|
||||||
|
video["keep_original_sound"] = keep
|
||||||
|
}
|
||||||
|
out = append(out, video)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
type kelingShotPrompt struct {
|
||||||
|
index int
|
||||||
|
text string
|
||||||
|
duration float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingShotPrompts(content []map[string]any) []kelingShotPrompt {
|
||||||
|
shots := make([]kelingShotPrompt, 0)
|
||||||
|
for index, item := range content {
|
||||||
|
if stringFromAny(item["type"]) != "text" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if stringFromAny(item["role"]) != "shot_prompt" && item["shot_index"] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text := strings.TrimSpace(stringFromAny(item["text"]))
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shotIndex := int(math.Floor(numericValue(item["shot_index"], float64(index))))
|
||||||
|
shots = append(shots, kelingShotPrompt{index: shotIndex, text: text, duration: numericValue(item["duration"], 5)})
|
||||||
|
}
|
||||||
|
sort.SliceStable(shots, func(i, j int) bool { return shots[i].index < shots[j].index })
|
||||||
|
return shots
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingHasBaseVideo(videos []map[string]any) bool {
|
||||||
|
for _, video := range videos {
|
||||||
|
if stringFromAny(video["refer_type"]) == "base" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingHasFirstFrame(images []any) bool {
|
||||||
|
for _, item := range images {
|
||||||
|
image := mapFromAny(item)
|
||||||
|
if stringFromAny(image["type"]) == "first_frame" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingCreateElementPayload(inline map[string]any) map[string]any {
|
||||||
|
frontURL := strings.TrimSpace(firstNonEmptyStringValue(inline, "frontal_image_url", "frontalImageUrl", "element_frontal_image", "image_url", "imageUrl", "url"))
|
||||||
|
if frontURL == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
name := firstNonEmptyStringValue(inline, "name", "element_name", "elementName")
|
||||||
|
if name == "" {
|
||||||
|
name = "temporary element"
|
||||||
|
}
|
||||||
|
payload := map[string]any{
|
||||||
|
"element_name": name,
|
||||||
|
"element_description": firstNonEmpty(firstNonEmptyStringValue(inline, "description"), name),
|
||||||
|
"element_frontal_image": frontURL,
|
||||||
|
}
|
||||||
|
referImages := make([]any, 0)
|
||||||
|
for _, ref := range mapListFromAny(firstPresent(inline["refer_images"], inline["referImages"], inline["element_refer_list"])) {
|
||||||
|
url := strings.TrimSpace(firstNonEmptyStringValue(ref, "url", "image_url", "imageUrl"))
|
||||||
|
if url != "" {
|
||||||
|
referImages = append(referImages, map[string]any{"image_url": url})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(referImages) > 0 {
|
||||||
|
payload["element_refer_list"] = referImages
|
||||||
|
}
|
||||||
|
if tags := kelingElementTagList(inline["tags"]); len(tags) > 0 {
|
||||||
|
payload["tag_list"] = tags
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingElementTagList(value any) []any {
|
||||||
|
mapping := map[string]string{
|
||||||
|
"hot_meme": "o_101",
|
||||||
|
"character": "o_102",
|
||||||
|
"animal": "o_103",
|
||||||
|
"prop": "o_104",
|
||||||
|
"costume": "o_105",
|
||||||
|
"scene": "o_106",
|
||||||
|
"effect": "o_107",
|
||||||
|
"other": "o_108",
|
||||||
|
}
|
||||||
|
out := make([]any, 0)
|
||||||
|
for _, tag := range stringListFromAny(value) {
|
||||||
|
id := mapping[strings.TrimSpace(tag)]
|
||||||
|
if id == "" {
|
||||||
|
id = mapping["other"]
|
||||||
|
}
|
||||||
|
out = append(out, map[string]any{"tag_id": id})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingNestedURL(item map[string]any, key string) string {
|
||||||
|
nested := mapFromAny(item[key])
|
||||||
|
if nested != nil {
|
||||||
|
if value := strings.TrimSpace(stringFromAny(nested["url"])); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(stringFromAny(item[key]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isKelingImageContent(item map[string]any) bool {
|
||||||
|
return stringFromAny(item["type"]) == "image_url" || mapFromAny(item["image_url"]) != nil || strings.TrimSpace(stringFromAny(item["image_url"])) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isKelingVideoContent(item map[string]any) bool {
|
||||||
|
return stringFromAny(item["type"]) == "video_url" || mapFromAny(item["video_url"]) != nil || strings.TrimSpace(stringFromAny(item["video_url"])) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingModeByResolution(resolution string) string {
|
||||||
|
switch strings.TrimSpace(resolution) {
|
||||||
|
case "2160p":
|
||||||
|
return "4k"
|
||||||
|
case "1080p":
|
||||||
|
return "pro"
|
||||||
|
case "480p", "720p", "":
|
||||||
|
return "std"
|
||||||
|
default:
|
||||||
|
if strings.HasSuffix(strings.TrimSpace(resolution), "p") {
|
||||||
|
return "std"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingCameraControl(body map[string]any) map[string]any {
|
||||||
|
cameraControl := strings.TrimSpace(stringFromAny(body["camera_control"]))
|
||||||
|
if cameraControl == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(cameraControl, "simple") {
|
||||||
|
directions := []string{"horizontal", "vertical", "pan", "tilt", "roll", "zoom"}
|
||||||
|
current := ""
|
||||||
|
parts := strings.SplitN(cameraControl, ":", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
current = parts[1]
|
||||||
|
}
|
||||||
|
strength := firstPresent(body["camera_control_strength"], body["cameraControlStrength"])
|
||||||
|
config := map[string]any{}
|
||||||
|
for _, direction := range directions {
|
||||||
|
if direction == current {
|
||||||
|
config[direction] = strength
|
||||||
|
} else {
|
||||||
|
config[direction] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]any{"type": "simple", "config": config}
|
||||||
|
}
|
||||||
|
return map[string]any{"type": cameraControl}
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingData(result map[string]any) map[string]any {
|
||||||
|
data, _ := result["data"].(map[string]any)
|
||||||
|
if data == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingTaskStatus(result map[string]any) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(stringFromAny(kelingData(result)["task_status"])))
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingTaskErrorCode(result map[string]any) string {
|
||||||
|
if code := intFromAny(result["code"]); code != 0 {
|
||||||
|
return fmt.Sprintf("keling_%d", code)
|
||||||
|
}
|
||||||
|
return "keling_task_failed"
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingTaskErrorMessage(candidate store.RuntimeModelCandidate, result map[string]any) string {
|
||||||
|
message := strings.TrimSpace(stringFromAny(kelingData(result)["task_status_msg"]))
|
||||||
|
if message == "" {
|
||||||
|
message = strings.TrimSpace(stringFromAny(result["message"]))
|
||||||
|
}
|
||||||
|
if message == "" {
|
||||||
|
message = "keling video task failed"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Platform:%s,Code:%v,requestId:%s,message:%s", candidate.Provider, result["code"], stringFromAny(result["request_id"]), message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingEnvelopeErrorCode(result map[string]any) string {
|
||||||
|
if code := intFromAny(result["code"]); code != 0 {
|
||||||
|
return fmt.Sprintf("keling_%d", code)
|
||||||
|
}
|
||||||
|
return "keling_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingEnvelopeErrorMessage(result map[string]any) string {
|
||||||
|
if message := strings.TrimSpace(stringFromAny(result["message"])); message != "" {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
return "keling request failed"
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingVideoSuccessResult(request Request, upstreamTaskID string, raw map[string]any) map[string]any {
|
||||||
|
data := kelingData(raw)
|
||||||
|
taskResult, _ := data["task_result"].(map[string]any)
|
||||||
|
videos, _ := taskResult["videos"].([]any)
|
||||||
|
items := make([]any, 0, len(videos))
|
||||||
|
for _, rawVideo := range videos {
|
||||||
|
video := mapFromAny(rawVideo)
|
||||||
|
url := strings.TrimSpace(stringFromAny(video["url"]))
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
item := map[string]any{"url": url, "video_url": url, "type": "video"}
|
||||||
|
if duration := intFromAny(video["duration"]); duration > 0 {
|
||||||
|
item["duration"] = duration
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
}
|
||||||
|
created := intFromAny(data["created_at"])
|
||||||
|
if created == 0 {
|
||||||
|
created = int(nowUnix())
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"id": upstreamTaskID,
|
||||||
|
"object": "video.generation",
|
||||||
|
"created": created,
|
||||||
|
"model": upstreamModelName(request.Candidate),
|
||||||
|
"status": "succeeded",
|
||||||
|
"upstream_task_id": upstreamTaskID,
|
||||||
|
"data": items,
|
||||||
|
"raw": raw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingVideoProgress(request Request, upstreamTaskID string) []Progress {
|
||||||
|
progress := providerProgress(request)
|
||||||
|
progress = append(progress, Progress{
|
||||||
|
Phase: "polling_result",
|
||||||
|
Progress: 0.9,
|
||||||
|
Message: "keling video task completed",
|
||||||
|
Payload: map[string]any{"upstreamTaskId": upstreamTaskID},
|
||||||
|
})
|
||||||
|
return progress
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingPollEndpoint(request Request, fallback string) string {
|
||||||
|
for _, key := range []string{"endpoint", "pollEndpoint"} {
|
||||||
|
if value := strings.TrimSpace(stringFromAny(request.RemoteTaskPayload[key])); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingPollInterval(request Request) time.Duration {
|
||||||
|
ms := numericValue(firstPresent(request.Candidate.PlatformConfig["kelingPollIntervalMs"], request.Candidate.PlatformConfig["klingPollIntervalMs"], request.Body["pollIntervalMs"], request.Body["poll_interval_ms"]), 15000)
|
||||||
|
if ms < 100 {
|
||||||
|
ms = 100
|
||||||
|
}
|
||||||
|
return time.Duration(ms) * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingPollTimeout(request Request) time.Duration {
|
||||||
|
seconds := numericValue(firstPresent(request.Candidate.PlatformConfig["kelingPollTimeoutSeconds"], request.Candidate.PlatformConfig["klingPollTimeoutSeconds"], request.Body["pollTimeoutSeconds"], request.Body["poll_timeout_seconds"]), 600)
|
||||||
|
if seconds < 1 {
|
||||||
|
seconds = 600
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func fmtDuration(value any, fallback float64) string {
|
||||||
|
duration := numericValue(value, fallback)
|
||||||
|
if math.Abs(duration-math.Round(duration)) < 1e-9 {
|
||||||
|
return fmt.Sprintf("%d", int(math.Round(duration)))
|
||||||
|
}
|
||||||
|
return strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", duration), "0"), ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteEmptyStringFields(payload map[string]any) {
|
||||||
|
for key, value := range payload {
|
||||||
|
if text, ok := value.(string); ok && strings.TrimSpace(text) == "" {
|
||||||
|
delete(payload, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func kelingStringFromAny(value any) string {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case json.Number:
|
||||||
|
return typed.String()
|
||||||
|
case float64:
|
||||||
|
if math.Abs(typed-math.Round(typed)) < 1e-9 {
|
||||||
|
return fmt.Sprintf("%.0f", typed)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v", typed)
|
||||||
|
case int:
|
||||||
|
return fmt.Sprintf("%d", typed)
|
||||||
|
case int64:
|
||||||
|
return fmt.Sprintf("%d", typed)
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
232
apps/api/internal/clients/media_clients.go
Normal file
232
apps/api/internal/clients/media_clients.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JimengClient struct{ HTTPClient *http.Client }
|
||||||
|
type BlackforestClient struct{ HTTPClient *http.Client }
|
||||||
|
type HunyuanImageClient struct{ HTTPClient *http.Client }
|
||||||
|
type HunyuanVideoClient struct{ HTTPClient *http.Client }
|
||||||
|
type MinimaxClient struct{ HTTPClient *http.Client }
|
||||||
|
type MidjourneyClient struct{ HTTPClient *http.Client }
|
||||||
|
type ViduClient struct{ HTTPClient *http.Client }
|
||||||
|
type AliyunBailianClient struct{ HTTPClient *http.Client }
|
||||||
|
type NewAPIClient struct{ HTTPClient *http.Client }
|
||||||
|
|
||||||
|
func (c JimengClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: jimengSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c BlackforestClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: blackforestSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c HunyuanImageClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: hunyuanImageSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c HunyuanVideoClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: hunyuanVideoSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MinimaxClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: minimaxSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MidjourneyClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: midjourneySpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ViduClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: viduSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c AliyunBailianClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: aliyunBailianSpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c NewAPIClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: newAPISpec()}.Run(ctx, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func jimengSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "jimeng",
|
||||||
|
SubmitPath: func(request Request, _ map[string]any) string {
|
||||||
|
return configuredPath(request, "?Action=CVSubmitTask&Version=2022-08-31", "submitPath", "submit_path")
|
||||||
|
},
|
||||||
|
PollPath: func(request Request, _ string, _ map[string]any) string {
|
||||||
|
return configuredPath(request, "?Action=CVSync2AsyncGetResult&Version=2022-08-31", "pollPath", "poll_path")
|
||||||
|
},
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"data.task_id"},
|
||||||
|
StatusPaths: []string{"data.status"},
|
||||||
|
SuccessStatuses: []string{"done"},
|
||||||
|
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||||
|
body["req_key"] = upstreamModelName(request.Candidate)
|
||||||
|
if body["prompt"] == nil {
|
||||||
|
body["prompt"] = mediaPromptText(body)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func blackforestSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "blackforest",
|
||||||
|
SubmitPath: func(request Request, body map[string]any) string {
|
||||||
|
return configuredPath(request, "/"+upstreamModelName(request.Candidate), "submitPath", "submit_path")
|
||||||
|
},
|
||||||
|
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return upstreamTaskID },
|
||||||
|
Auth: "x-key",
|
||||||
|
TaskIDPaths: []string{"polling_url"},
|
||||||
|
StatusPaths: []string{"status"},
|
||||||
|
SuccessStatuses: []string{"ready"},
|
||||||
|
FailureStatuses: []string{"error", "task not found"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hunyuanImageSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "tencent-hunyuan-image",
|
||||||
|
SubmitPath: func(request Request, _ map[string]any) string {
|
||||||
|
return configuredPath(request, "?Action=SubmitHunyuanImageJob&Version=2023-09-01", "submitPath", "submit_path")
|
||||||
|
},
|
||||||
|
PollPath: func(request Request, _ string, _ map[string]any) string {
|
||||||
|
return configuredPath(request, "?Action=QueryHunyuanImageJob&Version=2023-09-01&JobId=${taskId}", "pollPath", "poll_path")
|
||||||
|
},
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"Response.JobId"},
|
||||||
|
StatusPaths: []string{"Response.Status"},
|
||||||
|
SuccessStatuses: []string{"done"},
|
||||||
|
FailureStatuses: []string{"fail"},
|
||||||
|
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||||
|
body["Prompt"] = mediaPromptText(body)
|
||||||
|
body["Model"] = upstreamModelName(request.Candidate)
|
||||||
|
return body
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hunyuanVideoSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "tencent-hunyuan-video",
|
||||||
|
SubmitPath: func(request Request, _ map[string]any) string {
|
||||||
|
return configuredPath(request, "?Action=SubmitTextToVideoJob&Version=2024-01-01", "submitPath", "submit_path")
|
||||||
|
},
|
||||||
|
PollPath: func(request Request, _ string, _ map[string]any) string {
|
||||||
|
return configuredPath(request, "?Action=QueryVideoJob&Version=2024-01-01&JobId=${taskId}", "pollPath", "poll_path")
|
||||||
|
},
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"Response.JobId"},
|
||||||
|
StatusPaths: []string{"Response.Status"},
|
||||||
|
SuccessStatuses: []string{"done"},
|
||||||
|
FailureStatuses: []string{"fail"},
|
||||||
|
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||||
|
body["Prompt"] = mediaPromptText(body)
|
||||||
|
body["Model"] = upstreamModelName(request.Candidate)
|
||||||
|
return body
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func minimaxSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "minimax",
|
||||||
|
SubmitPath: func(Request, map[string]any) string { return "/video_generation" },
|
||||||
|
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
|
||||||
|
return "/query/video_generation?task_id=" + upstreamTaskID
|
||||||
|
},
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"task_id"},
|
||||||
|
StatusPaths: []string{"status"},
|
||||||
|
SuccessStatuses: []string{"success"},
|
||||||
|
FailureStatuses: []string{"failed", "expired"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func midjourneySpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "midjourney",
|
||||||
|
SubmitPath: func(request Request, body map[string]any) string {
|
||||||
|
return configuredPath(request, "/diffusion", "submitPath", "submit_path")
|
||||||
|
},
|
||||||
|
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/job/" + upstreamTaskID },
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"job_id", "id"},
|
||||||
|
StatusPaths: []string{"status"},
|
||||||
|
SuccessStatuses: []string{"success", "completed"},
|
||||||
|
FailureStatuses: []string{"failed"},
|
||||||
|
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||||
|
if body["prompt"] == nil && body["text"] == nil {
|
||||||
|
body["prompt"] = mediaPromptText(body)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func viduSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "vidu",
|
||||||
|
SubmitPath: func(request Request, body map[string]any) string {
|
||||||
|
if path := configuredPath(request, "", "submitPath", "submit_path"); path != "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
taskType := firstNonEmptyString(body["type"], body["task_type"], "text2video")
|
||||||
|
if taskType == "multiframe" {
|
||||||
|
return "/multiframe"
|
||||||
|
}
|
||||||
|
return "/" + taskType
|
||||||
|
},
|
||||||
|
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
|
||||||
|
return "/tasks/" + upstreamTaskID + "/creations"
|
||||||
|
},
|
||||||
|
Auth: "token",
|
||||||
|
TaskIDPaths: []string{"task_id"},
|
||||||
|
StatusPaths: []string{"state", "status"},
|
||||||
|
SuccessStatuses: []string{"success", "succeeded"},
|
||||||
|
FailureStatuses: []string{"failed"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aliyunBailianSpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "aliyun-bailian",
|
||||||
|
SubmitPath: func(Request, map[string]any) string { return "/services/aigc/video-generation/video-synthesis" },
|
||||||
|
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/tasks/" + upstreamTaskID },
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"output.task_id"},
|
||||||
|
StatusPaths: []string{"output.task_status"},
|
||||||
|
SuccessStatuses: []string{"succeeded", "success"},
|
||||||
|
FailureStatuses: []string{"failed"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAPISpec() providerTaskSpec {
|
||||||
|
return providerTaskSpec{
|
||||||
|
Name: "newapi",
|
||||||
|
SubmitPath: func(Request, map[string]any) string { return "/videos/generations" },
|
||||||
|
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
|
||||||
|
return "/videos/generations/" + upstreamTaskID
|
||||||
|
},
|
||||||
|
Auth: "bearer",
|
||||||
|
TaskIDPaths: []string{"task_id"},
|
||||||
|
StatusPaths: []string{"status"},
|
||||||
|
SuccessStatuses: []string{"success"},
|
||||||
|
FailureStatuses: []string{"failure", "failed"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configuredPath(request Request, fallback string, keys ...string) string {
|
||||||
|
for _, key := range keys {
|
||||||
|
if value := strings.TrimSpace(stringFromAny(request.Candidate.PlatformConfig[key])); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
@ -23,6 +23,9 @@ func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error
|
|||||||
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported openai request kind", Retryable: false}
|
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported openai request kind", Retryable: false}
|
||||||
}
|
}
|
||||||
body := cloneBody(request.Body)
|
body := cloneBody(request.Body)
|
||||||
|
if request.Kind == "chat.completions" {
|
||||||
|
body = NormalizeChatCompletionRequestBody(body)
|
||||||
|
}
|
||||||
body["model"] = upstreamModelName(request.Candidate)
|
body["model"] = upstreamModelName(request.Candidate)
|
||||||
stream := request.Stream || boolValue(body, "stream")
|
stream := request.Stream || boolValue(body, "stream")
|
||||||
ensureOpenAIStreamUsage(body, request.Kind, stream)
|
ensureOpenAIStreamUsage(body, request.Kind, stream)
|
||||||
@ -33,13 +36,16 @@ func (c OpenAIClient) Run(ctx context.Context, request Request) (Response, error
|
|||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
responseStartedAt := time.Now()
|
||||||
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
}
|
}
|
||||||
responseStartedAt := time.Now()
|
|
||||||
requestID := requestIDFromHTTPResponse(resp)
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
result, err := decodeOpenAIResponse(resp, stream, request.StreamDelta)
|
result, err := decodeOpenAIResponse(resp, stream, request.StreamDelta)
|
||||||
|
if err == nil && request.Kind == "chat.completions" {
|
||||||
|
result = NormalizeChatCompletionResult(result)
|
||||||
|
}
|
||||||
responseFinishedAt := time.Now()
|
responseFinishedAt := time.Now()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
|
return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
|
||||||
|
|||||||
453
apps/api/internal/clients/provider_task.go
Normal file
453
apps/api/internal/clients/provider_task.go
Normal file
@ -0,0 +1,453 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type providerTaskSpec struct {
|
||||||
|
Name string
|
||||||
|
SubmitPath func(Request, map[string]any) string
|
||||||
|
PollPath func(Request, string, map[string]any) string
|
||||||
|
Auth string
|
||||||
|
TaskIDPaths []string
|
||||||
|
StatusPaths []string
|
||||||
|
SuccessStatuses []string
|
||||||
|
FailureStatuses []string
|
||||||
|
ProcessStatuses []string
|
||||||
|
DefaultSubmitBody func(Request, map[string]any) map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type providerTaskClient struct {
|
||||||
|
HTTPClient *http.Client
|
||||||
|
Spec providerTaskSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c providerTaskClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
if request.Kind != "images.generations" && request.Kind != "images.edits" && request.Kind != "videos.generations" {
|
||||||
|
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported " + c.Spec.Name + " request kind", Retryable: false}
|
||||||
|
}
|
||||||
|
startedAt := time.Now()
|
||||||
|
payload := cloneBody(request.Body)
|
||||||
|
if c.Spec.DefaultSubmitBody != nil {
|
||||||
|
payload = c.Spec.DefaultSubmitBody(request, payload)
|
||||||
|
} else {
|
||||||
|
payload["model"] = upstreamModelName(request.Candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
||||||
|
requestID := upstreamTaskID
|
||||||
|
var submitResult map[string]any
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
result, id, err := c.submit(ctx, request, payload)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, annotateResponseError(err, id, startedAt, time.Now())
|
||||||
|
}
|
||||||
|
submitResult = result
|
||||||
|
requestID = firstNonEmptyString(id, requestIDFromResult(result))
|
||||||
|
if isProviderTaskFailure(c.Spec, result) {
|
||||||
|
return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
|
||||||
|
}
|
||||||
|
if isProviderTaskSuccess(c.Spec, result) && hasProviderTaskResult(result) {
|
||||||
|
return Response{
|
||||||
|
Result: normalizeProviderTaskResult(request, c.Spec, result, ""),
|
||||||
|
RequestID: requestID,
|
||||||
|
Progress: providerProgress(request),
|
||||||
|
ResponseStartedAt: startedAt,
|
||||||
|
ResponseFinishedAt: time.Now(),
|
||||||
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
upstreamTaskID = providerTaskID(c.Spec, result)
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
return Response{}, &ClientError{Code: "invalid_response", Message: c.Spec.Name + " task id is missing", RequestID: requestID, Retryable: false}
|
||||||
|
}
|
||||||
|
if request.OnRemoteTaskSubmitted != nil {
|
||||||
|
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if request.RemoteTaskPayload != nil {
|
||||||
|
if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
|
||||||
|
payload = existingPayload
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interval := providerPollInterval(request)
|
||||||
|
timeout := providerPollTimeout(request)
|
||||||
|
deadline := time.NewTimer(timeout)
|
||||||
|
defer deadline.Stop()
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastResult map[string]any
|
||||||
|
for {
|
||||||
|
pollStarted := time.Now()
|
||||||
|
result, pollRequestID, err := c.poll(ctx, request, upstreamTaskID, payload)
|
||||||
|
pollFinished := time.Now()
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
|
||||||
|
}
|
||||||
|
lastResult = result
|
||||||
|
requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
|
||||||
|
if isProviderTaskSuccess(c.Spec, result) {
|
||||||
|
finishedAt := time.Now()
|
||||||
|
return Response{
|
||||||
|
Result: normalizeProviderTaskResult(request, c.Spec, result, upstreamTaskID),
|
||||||
|
RequestID: requestID,
|
||||||
|
Progress: append(providerProgress(request), Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}),
|
||||||
|
ResponseStartedAt: startedAt,
|
||||||
|
ResponseFinishedAt: finishedAt,
|
||||||
|
ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if isProviderTaskFailure(c.Spec, result) {
|
||||||
|
return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
||||||
|
case <-deadline.C:
|
||||||
|
return Response{}, &ClientError{Code: "timeout", Message: fmt.Sprintf("%s task %s did not finish before timeout; last status: %s", c.Spec.Name, upstreamTaskID, providerTaskStatus(c.Spec, lastResult)), RequestID: requestID, Retryable: true}
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c providerTaskClient) submit(ctx context.Context, request Request, payload map[string]any) (map[string]any, string, error) {
|
||||||
|
path := c.Spec.SubmitPath(request, payload)
|
||||||
|
return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), providerURL(request.Candidate.BaseURL, path), payload, request.Candidate.Credentials, c.Spec.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c providerTaskClient) poll(ctx context.Context, request Request, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
|
||||||
|
path := resolveProviderPathTemplate(c.Spec.PollPath(request, upstreamTaskID, payload), upstreamTaskID)
|
||||||
|
url := path
|
||||||
|
if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") {
|
||||||
|
url = providerURL(request.Candidate.BaseURL, path)
|
||||||
|
}
|
||||||
|
if c.Spec.Name == "jimeng" {
|
||||||
|
body := map[string]any{"task_id": upstreamTaskID, "req_key": upstreamModelName(request.Candidate)}
|
||||||
|
return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, body, request.Candidate.Credentials, c.Spec.Auth)
|
||||||
|
}
|
||||||
|
return providerGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, request.Candidate.Credentials, c.Spec.Auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerPostJSON(ctx context.Context, client *http.Client, url string, body map[string]any, credentials map[string]any, auth string) (map[string]any, string, error) {
|
||||||
|
raw, _ := json.Marshal(body)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
applyProviderAuth(req, credentials, auth)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
|
result, err := decodeHTTPResponse(resp)
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any, auth string) (map[string]any, string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
applyProviderAuth(req, credentials, auth)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
|
result, err := decodeHTTPResponse(resp)
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyProviderAuth(req *http.Request, credentials map[string]any, auth string) {
|
||||||
|
apiKey := credential(credentials, "apiKey", "api_key", "key", "token")
|
||||||
|
switch auth {
|
||||||
|
case "token":
|
||||||
|
if apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Token "+apiKey)
|
||||||
|
}
|
||||||
|
case "x-key":
|
||||||
|
if apiKey != "" {
|
||||||
|
req.Header.Set("x-key", apiKey)
|
||||||
|
}
|
||||||
|
case "none":
|
||||||
|
default:
|
||||||
|
if apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerURL(base string, path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if path == "" {
|
||||||
|
path = "/"
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(path, "/") && !strings.HasPrefix(path, "?") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
return joinURL(base, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveProviderPathTemplate(path string, upstreamTaskID string) string {
|
||||||
|
replacements := [][2]string{
|
||||||
|
{"${upstream_task_id}", upstreamTaskID},
|
||||||
|
{"{{upstream_task_id}}", upstreamTaskID},
|
||||||
|
{"{upstream_task_id}", upstreamTaskID},
|
||||||
|
{"${task_id}", upstreamTaskID},
|
||||||
|
{"{{task_id}}", upstreamTaskID},
|
||||||
|
{"{task_id}", upstreamTaskID},
|
||||||
|
{"${taskId}", upstreamTaskID},
|
||||||
|
{"${taskID}", upstreamTaskID},
|
||||||
|
{"{{taskId}}", upstreamTaskID},
|
||||||
|
{"{{taskID}}", upstreamTaskID},
|
||||||
|
{"{taskId}", upstreamTaskID},
|
||||||
|
{"{taskID}", upstreamTaskID},
|
||||||
|
}
|
||||||
|
for _, replacement := range replacements {
|
||||||
|
path = strings.ReplaceAll(path, replacement[0], replacement[1])
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerTaskID(spec providerTaskSpec, result map[string]any) string {
|
||||||
|
paths := append([]string{}, spec.TaskIDPaths...)
|
||||||
|
paths = append(paths, "task_id", "taskId", "id", "job_id", "Response.JobId", "output.task_id", "data.task_id", "polling_url")
|
||||||
|
for _, path := range paths {
|
||||||
|
if value := stringFromPathValue(valueAtPath(result, path)); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerTaskStatus(spec providerTaskSpec, result map[string]any) string {
|
||||||
|
if result == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if value, ok := valueAtPath(result, "status").(float64); ok {
|
||||||
|
if value == 2 {
|
||||||
|
return "success"
|
||||||
|
}
|
||||||
|
if value == 3 {
|
||||||
|
return "failed"
|
||||||
|
}
|
||||||
|
return "process"
|
||||||
|
}
|
||||||
|
paths := append([]string{}, spec.StatusPaths...)
|
||||||
|
paths = append(paths, "status", "state", "task_status", "output.task_status", "Response.Status", "data.status")
|
||||||
|
for _, path := range paths {
|
||||||
|
if value := stringFromPathValue(valueAtPath(result, path)); value != "" {
|
||||||
|
return strings.ToLower(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringFromPathValue(value any) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
text := strings.TrimSpace(fmt.Sprint(value))
|
||||||
|
if text == "" || text == "<nil>" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProviderTaskSuccess(spec providerTaskSpec, result map[string]any) bool {
|
||||||
|
return containsStatus(append([]string{"success", "succeeded", "completed", "complete", "done", "ready", "succeed", "succeeded", "suceeded", "done", "done"}, spec.SuccessStatuses...), providerTaskStatus(spec, result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProviderTaskFailure(spec providerTaskSpec, result map[string]any) bool {
|
||||||
|
return containsStatus(append([]string{"failed", "failure", "error", "cancelled", "canceled", "fail", "expired", "task not found"}, spec.FailureStatuses...), providerTaskStatus(spec, result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsStatus(values []string, status string) bool {
|
||||||
|
status = strings.ToLower(strings.TrimSpace(status))
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.ToLower(strings.TrimSpace(value)) == status {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasProviderTaskResult(result map[string]any) bool {
|
||||||
|
return result["data"] != nil || valueAtPath(result, "output.image_urls") != nil || valueAtPath(result, "output.video_url") != nil || valueAtPath(result, "Response.ResultVideoUrl") != nil || valueAtPath(result, "Response.ResultImages") != nil || result["urls"] != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeProviderTaskResult(request Request, spec providerTaskSpec, result map[string]any, upstreamTaskID string) map[string]any {
|
||||||
|
out := cloneMapAny(result)
|
||||||
|
out["status"] = "success"
|
||||||
|
if upstreamTaskID != "" {
|
||||||
|
out["upstream_task_id"] = upstreamTaskID
|
||||||
|
}
|
||||||
|
if out["created"] == nil {
|
||||||
|
out["created"] = time.Now().UnixMilli()
|
||||||
|
}
|
||||||
|
if out["model"] == nil {
|
||||||
|
out["model"] = request.Model
|
||||||
|
}
|
||||||
|
if _, ok := out["data"].([]any); !ok {
|
||||||
|
if out["data"] != nil {
|
||||||
|
out["raw_data"] = out["data"]
|
||||||
|
}
|
||||||
|
out["data"] = providerTaskData(request, result)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerTaskData(request Request, result map[string]any) []any {
|
||||||
|
fileType := "image"
|
||||||
|
if request.Kind == "videos.generations" || strings.Contains(request.ModelType, "video") {
|
||||||
|
fileType = "video"
|
||||||
|
}
|
||||||
|
urlValues := []any{}
|
||||||
|
for _, path := range []string{
|
||||||
|
"urls",
|
||||||
|
"image_urls",
|
||||||
|
"data.image_urls",
|
||||||
|
"data.images",
|
||||||
|
"output.image_urls",
|
||||||
|
"output.video_url",
|
||||||
|
"output.output",
|
||||||
|
"data.output",
|
||||||
|
"data.video_url",
|
||||||
|
"video_url",
|
||||||
|
"preview_url",
|
||||||
|
"Response.ResultImages",
|
||||||
|
"Response.ResultVideoUrl",
|
||||||
|
} {
|
||||||
|
appendURLValues(&urlValues, valueAtPath(result, path))
|
||||||
|
}
|
||||||
|
data := make([]any, 0, len(urlValues))
|
||||||
|
for _, raw := range urlValues {
|
||||||
|
if url := strings.TrimSpace(fmt.Sprint(raw)); url != "" {
|
||||||
|
data = append(data, map[string]any{"type": fileType, "url": url})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
if base64Values := valueAtPath(result, "data.binary_data_base64"); base64Values != nil {
|
||||||
|
values := []any{}
|
||||||
|
appendURLValues(&values, base64Values)
|
||||||
|
for _, raw := range values {
|
||||||
|
if content := strings.TrimSpace(fmt.Sprint(raw)); content != "" {
|
||||||
|
data = append(data, map[string]any{"type": fileType, "content": content, "uploaded": false})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendURLValues(out *[]any, value any) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case nil:
|
||||||
|
case string:
|
||||||
|
*out = append(*out, typed)
|
||||||
|
case []any:
|
||||||
|
for _, item := range typed {
|
||||||
|
appendURLValues(out, item)
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, item := range typed {
|
||||||
|
*out = append(*out, item)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
for _, key := range []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "content", "output"} {
|
||||||
|
if item := strings.TrimSpace(fmt.Sprint(typed[key])); item != "" && item != "<nil>" {
|
||||||
|
*out = append(*out, item)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerTaskFailure(spec providerTaskSpec, result map[string]any, requestID string, startedAt time.Time) error {
|
||||||
|
message := firstNonEmptyString(valueAtPath(result, "message"), valueAtPath(result, "error.message"), valueAtPath(result, "error"), valueAtPath(result, "Response.ErrorMessage"), valueAtPath(result, "comment"), spec.Name+" task failed")
|
||||||
|
return &ClientError{
|
||||||
|
Code: firstNonEmptyString(valueAtPath(result, "code"), valueAtPath(result, "error_code"), valueAtPath(result, "Response.ErrorCode"), "provider_failed"),
|
||||||
|
Message: message,
|
||||||
|
RequestID: requestID,
|
||||||
|
ResponseStartedAt: startedAt,
|
||||||
|
ResponseFinishedAt: time.Now(),
|
||||||
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||||
|
Retryable: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerPollInterval(request Request) time.Duration {
|
||||||
|
return durationFromConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerPollTimeout(request Request) time.Duration {
|
||||||
|
return durationFromConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
|
||||||
|
}
|
||||||
|
|
||||||
|
func durationFromConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
|
||||||
|
for _, key := range keys {
|
||||||
|
switch value := config[key].(type) {
|
||||||
|
case int:
|
||||||
|
if value > 0 {
|
||||||
|
return time.Duration(value) * time.Millisecond
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if value > 0 {
|
||||||
|
return time.Duration(value) * time.Millisecond
|
||||||
|
}
|
||||||
|
case float64:
|
||||||
|
if value > 0 {
|
||||||
|
return time.Duration(value) * time.Millisecond
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func valueAtPath(values map[string]any, path string) any {
|
||||||
|
if values == nil || strings.TrimSpace(path) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var current any = values
|
||||||
|
for _, part := range strings.Split(path, ".") {
|
||||||
|
object, ok := current.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
current = object[part]
|
||||||
|
}
|
||||||
|
return current
|
||||||
|
}
|
||||||
|
|
||||||
|
func mediaPromptText(body map[string]any) string {
|
||||||
|
if prompt := strings.TrimSpace(stringFromAny(body["prompt"])); prompt != "" {
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
content, _ := body["content"].([]any)
|
||||||
|
for _, item := range content {
|
||||||
|
if part, ok := item.(map[string]any); ok && strings.TrimSpace(stringFromAny(part["type"])) == "text" {
|
||||||
|
if text := strings.TrimSpace(stringFromAny(part["text"])); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
339
apps/api/internal/clients/provider_task_test.go
Normal file
339
apps/api/internal/clients/provider_task_test.go
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProviderTaskClientsSubmitAndPoll(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
client Client
|
||||||
|
provider string
|
||||||
|
specType string
|
||||||
|
submitMatch func(*http.Request) bool
|
||||||
|
submitResponse string
|
||||||
|
pollMatch func(*http.Request) bool
|
||||||
|
pollResponse string
|
||||||
|
authHeader string
|
||||||
|
resultURL string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "jimeng",
|
||||||
|
client: JimengClient{},
|
||||||
|
provider: "jimeng",
|
||||||
|
specType: "jimeng",
|
||||||
|
submitMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask"
|
||||||
|
},
|
||||||
|
submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`,
|
||||||
|
pollMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult"
|
||||||
|
},
|
||||||
|
pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/jimeng.mp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blackforest",
|
||||||
|
client: BlackforestClient{},
|
||||||
|
provider: "blackforest",
|
||||||
|
specType: "blackforest",
|
||||||
|
submitMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model")
|
||||||
|
},
|
||||||
|
submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`,
|
||||||
|
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1" },
|
||||||
|
pollResponse: `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`,
|
||||||
|
authHeader: "test-key",
|
||||||
|
resultURL: "https://cdn.example/flux.png",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hunyuan-image",
|
||||||
|
client: HunyuanImageClient{},
|
||||||
|
provider: "tencent-hunyuan-image",
|
||||||
|
specType: "tencent-hunyuan-image",
|
||||||
|
submitMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob"
|
||||||
|
},
|
||||||
|
submitResponse: `{"Response":{"JobId":"remote-1"}}`,
|
||||||
|
pollMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob"
|
||||||
|
},
|
||||||
|
pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/hunyuan.png",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hunyuan-video",
|
||||||
|
client: HunyuanVideoClient{},
|
||||||
|
provider: "tencent-hunyuan-video",
|
||||||
|
specType: "tencent-hunyuan-video",
|
||||||
|
submitMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob"
|
||||||
|
},
|
||||||
|
submitResponse: `{"Response":{"JobId":"remote-1"}}`,
|
||||||
|
pollMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob"
|
||||||
|
},
|
||||||
|
pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/hunyuan.mp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "minimax",
|
||||||
|
client: MinimaxClient{},
|
||||||
|
provider: "minimax",
|
||||||
|
specType: "minimax",
|
||||||
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/video_generation" },
|
||||||
|
submitResponse: `{"task_id":123}`,
|
||||||
|
pollMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123"
|
||||||
|
},
|
||||||
|
pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/minimax.mp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "midjourney",
|
||||||
|
client: MidjourneyClient{},
|
||||||
|
provider: "midjourney",
|
||||||
|
specType: "midjourney",
|
||||||
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/diffusion" },
|
||||||
|
submitResponse: `{"job_id":"remote-1"}`,
|
||||||
|
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1" },
|
||||||
|
pollResponse: `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/mj.png",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "vidu",
|
||||||
|
client: ViduClient{},
|
||||||
|
provider: "vidu",
|
||||||
|
specType: "vidu",
|
||||||
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/text2video" },
|
||||||
|
submitResponse: `{"task_id":"remote-1"}`,
|
||||||
|
pollMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations"
|
||||||
|
},
|
||||||
|
pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`,
|
||||||
|
authHeader: "Token test-key",
|
||||||
|
resultURL: "https://cdn.example/vidu.mp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "aliyun-bailian",
|
||||||
|
client: AliyunBailianClient{},
|
||||||
|
provider: "aliyun-bailian",
|
||||||
|
specType: "aliyun-bailian",
|
||||||
|
submitMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis"
|
||||||
|
},
|
||||||
|
submitResponse: `{"output":{"task_id":"remote-1"}}`,
|
||||||
|
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1" },
|
||||||
|
pollResponse: `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/aliyun.mp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "newapi",
|
||||||
|
client: NewAPIClient{},
|
||||||
|
provider: "newapi",
|
||||||
|
specType: "newapi",
|
||||||
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/videos/generations" },
|
||||||
|
submitResponse: `{"task_id":"remote-1"}`,
|
||||||
|
pollMatch: func(r *http.Request) bool {
|
||||||
|
return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1"
|
||||||
|
},
|
||||||
|
pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`,
|
||||||
|
authHeader: "Bearer test-key",
|
||||||
|
resultURL: "https://cdn.example/newapi.mp4",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if tc.authHeader == "test-key" {
|
||||||
|
if r.Header.Get("x-key") != tc.authHeader {
|
||||||
|
t.Fatalf("unexpected x-key header: %q", r.Header.Get("x-key"))
|
||||||
|
}
|
||||||
|
} else if r.Header.Get("Authorization") != tc.authHeader {
|
||||||
|
t.Fatalf("unexpected auth header: %q", r.Header.Get("Authorization"))
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("x-request-id", "req-"+tc.name)
|
||||||
|
switch {
|
||||||
|
case tc.submitMatch(r):
|
||||||
|
_, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host)))
|
||||||
|
case tc.pollMatch(r):
|
||||||
|
_, _ = w.Write([]byte(tc.pollResponse))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var submittedRemoteTaskID string
|
||||||
|
request := Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "alias-model",
|
||||||
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
Provider: tc.provider,
|
||||||
|
SpecType: tc.specType,
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||||
|
ModelName: "alias-model",
|
||||||
|
ProviderModelName: "provider-model",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
ClientID: tc.name + ":test",
|
||||||
|
},
|
||||||
|
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
|
||||||
|
submittedRemoteTaskID = remoteTaskID
|
||||||
|
if payload["payload"] == nil || payload["submit"] == nil {
|
||||||
|
t.Fatalf("missing remote payload: %#v", payload)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := tc.client.Run(context.Background(), request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run failed: %v", err)
|
||||||
|
}
|
||||||
|
data, ok := response.Result["data"].([]any)
|
||||||
|
if !ok || len(data) == 0 {
|
||||||
|
t.Fatalf("missing data: %#v", response.Result)
|
||||||
|
}
|
||||||
|
first, _ := data[0].(map[string]any)
|
||||||
|
if first["url"] != tc.resultURL {
|
||||||
|
t.Fatalf("unexpected result url: %#v", response.Result)
|
||||||
|
}
|
||||||
|
if response.RequestID != "req-"+tc.name {
|
||||||
|
t.Fatalf("unexpected request id: %q", response.RequestID)
|
||||||
|
}
|
||||||
|
if submittedRemoteTaskID == "" {
|
||||||
|
t.Fatalf("expected remote task submission")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) {
|
||||||
|
t.Run("poll failure", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("x-request-id", "req-failed")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/videos/generations":
|
||||||
|
_, _ = w.Write([]byte(`{"task_id":"remote-1"}`))
|
||||||
|
case "/videos/generations/remote-1":
|
||||||
|
_, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request: %s", r.URL.String())
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := (NewAPIClient{}).Run(context.Background(), Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "alias-model",
|
||||||
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
Provider: "newapi",
|
||||||
|
SpecType: "newapi",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||||
|
ModelName: "alias-model",
|
||||||
|
ProviderModelName: "provider-model",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
var clientErr *ClientError
|
||||||
|
if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable {
|
||||||
|
t.Fatalf("expected non-retryable upstream failure, got %#v", err)
|
||||||
|
}
|
||||||
|
if clientErr.RequestID != "req-failed" {
|
||||||
|
t.Fatalf("unexpected request id: %q", clientErr.RequestID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("submit rate limit", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("x-request-id", "req-rate-limit")
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := (NewAPIClient{}).Run(context.Background(), Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "alias-model",
|
||||||
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
Provider: "newapi",
|
||||||
|
SpecType: "newapi",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||||
|
ModelName: "alias-model",
|
||||||
|
ProviderModelName: "provider-model",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
var clientErr *ClientError
|
||||||
|
if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" {
|
||||||
|
t.Fatalf("expected retryable rate limit with request id, got %#v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == http.MethodPost {
|
||||||
|
t.Fatal("submit should not run for resumed remote task")
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
response, err := (NewAPIClient{}).Run(context.Background(), Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "alias-model",
|
||||||
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||||
|
RemoteTaskID: "remote-1",
|
||||||
|
RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
Provider: "newapi",
|
||||||
|
SpecType: "newapi",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||||
|
ModelName: "alias-model",
|
||||||
|
ProviderModelName: "provider-model",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run failed: %v", err)
|
||||||
|
}
|
||||||
|
data := response.Result["data"].([]any)
|
||||||
|
first := data[0].(map[string]any)
|
||||||
|
if first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" {
|
||||||
|
t.Fatalf("unexpected response: %#v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -48,7 +48,13 @@ type Progress struct {
|
|||||||
Payload map[string]any
|
Payload map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamDelta func(text string) error
|
type StreamDeltaEvent struct {
|
||||||
|
Text string
|
||||||
|
ReasoningContent string
|
||||||
|
Event map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamDelta func(event StreamDeltaEvent) error
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
Run(ctx context.Context, request Request) (Response, error)
|
Run(ctx context.Context, request Request) (Response, error)
|
||||||
@ -146,5 +152,8 @@ func responseDurationMS(startedAt time.Time, finishedAt time.Time) int64 {
|
|||||||
if duration < 0 {
|
if duration < 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
if duration == 0 && finishedAt.After(startedAt) {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
return duration
|
return duration
|
||||||
}
|
}
|
||||||
|
|||||||
481
apps/api/internal/clients/universal.go
Normal file
481
apps/api/internal/clients/universal.go
Normal file
@ -0,0 +1,481 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UniversalClient struct {
|
||||||
|
HTTPClient *http.Client
|
||||||
|
ScriptExecutor *scriptengine.Executor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c UniversalClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||||
|
executor := c.ScriptExecutor
|
||||||
|
if executor == nil {
|
||||||
|
executor = &scriptengine.Executor{}
|
||||||
|
}
|
||||||
|
startedAt := time.Now()
|
||||||
|
modelType := strings.TrimSpace(request.ModelType)
|
||||||
|
if modelType == "" {
|
||||||
|
modelType = strings.TrimSpace(request.Candidate.ModelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := cloneBody(request.Body)
|
||||||
|
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
||||||
|
submitRequestID := upstreamTaskID
|
||||||
|
var submitResult map[string]any
|
||||||
|
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
var err error
|
||||||
|
payload, err = c.universalGetParams(ctx, executor, request, modelType)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
submitResult, submitRequestID, err = c.universalSubmit(ctx, executor, request, modelType, payload)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now())
|
||||||
|
}
|
||||||
|
if isUniversalSuccess(submitResult) && submitResult["data"] != nil {
|
||||||
|
return Response{
|
||||||
|
Result: normalizeUniversalResult(request, submitResult, ""),
|
||||||
|
RequestID: firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)),
|
||||||
|
Progress: providerProgress(request),
|
||||||
|
ResponseStartedAt: startedAt,
|
||||||
|
ResponseFinishedAt: time.Now(),
|
||||||
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if isUniversalFailure(submitResult) {
|
||||||
|
return Response{}, universalFailureError(submitResult, firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)), startedAt)
|
||||||
|
}
|
||||||
|
upstreamTaskID = universalTaskID(submitResult)
|
||||||
|
if upstreamTaskID == "" {
|
||||||
|
return Response{}, &ClientError{Code: "invalid_response", Message: "universal task id is missing", RequestID: submitRequestID, Retryable: false}
|
||||||
|
}
|
||||||
|
if request.OnRemoteTaskSubmitted != nil {
|
||||||
|
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if request.RemoteTaskPayload != nil {
|
||||||
|
if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
|
||||||
|
payload = existingPayload
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, requestID, err := c.universalPollUntilDone(ctx, executor, request, modelType, upstreamTaskID, payload, firstNonEmptyString(submitRequestID, upstreamTaskID), startedAt)
|
||||||
|
if err != nil {
|
||||||
|
return Response{}, err
|
||||||
|
}
|
||||||
|
finishedAt := time.Now()
|
||||||
|
return Response{
|
||||||
|
Result: normalizeUniversalResult(request, result, upstreamTaskID),
|
||||||
|
RequestID: firstNonEmptyString(requestID, submitRequestID, requestIDFromResult(result), upstreamTaskID),
|
||||||
|
Progress: universalProgress(request, upstreamTaskID),
|
||||||
|
ResponseStartedAt: startedAt,
|
||||||
|
ResponseFinishedAt: finishedAt,
|
||||||
|
ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c UniversalClient) universalGetParams(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string) (map[string]any, error) {
|
||||||
|
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customGetParamsScript", "custom_get_params_script"); scriptText != "" {
|
||||||
|
scriptContext := universalScriptContext(request, modelType, nil)
|
||||||
|
out, err := executor.Execute(ctx, scriptengine.Options{
|
||||||
|
Script: scriptText,
|
||||||
|
Args: []any{cloneBody(request.Body), scriptContext},
|
||||||
|
ContextData: scriptContext,
|
||||||
|
ScriptName: "custom_get_params_script:" + modelType,
|
||||||
|
PreferredEntryNames: []string{"getGenerateParams", "getParams", "main", "handler"},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, universalScriptError(err)
|
||||||
|
}
|
||||||
|
if params, ok := out.(map[string]any); ok && params != nil {
|
||||||
|
if params["_originalParams"] == nil {
|
||||||
|
params["_originalParams"] = cloneBody(request.Body)
|
||||||
|
}
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
return nil, &ClientError{Code: "invalid_response", Message: "custom get params script must return an object", Retryable: false}
|
||||||
|
}
|
||||||
|
body := universalDefaultPayload(request)
|
||||||
|
body["_originalParams"] = cloneBody(request.Body)
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c UniversalClient) universalSubmit(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, payload map[string]any) (map[string]any, string, error) {
|
||||||
|
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customSubmitScript", "custom_submit_script"); scriptText != "" {
|
||||||
|
scriptContext := universalScriptContext(request, modelType, payload)
|
||||||
|
out, err := executor.Execute(ctx, scriptengine.Options{
|
||||||
|
Script: scriptText,
|
||||||
|
Args: []any{cloneBody(payload), scriptContext},
|
||||||
|
ContextData: scriptContext,
|
||||||
|
ScriptName: "custom_submit_script:" + modelType,
|
||||||
|
PreferredEntryNames: []string{"submitTask", "submitParams", "submit", "main", "handler"},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", universalScriptError(err)
|
||||||
|
}
|
||||||
|
result, ok := out.(map[string]any)
|
||||||
|
if !ok || result == nil {
|
||||||
|
return nil, "", &ClientError{Code: "invalid_response", Message: "custom submit script must return an object", Retryable: false}
|
||||||
|
}
|
||||||
|
return result, requestIDFromResult(result), nil
|
||||||
|
}
|
||||||
|
endpoint := universalSubmitEndpoint(request)
|
||||||
|
result, requestID, err := universalPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), request.Candidate.BaseURL, endpoint, universalStripPrivatePayload(payload), request.Candidate.Credentials)
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c UniversalClient) universalPollUntilDone(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any, requestID string, startedAt time.Time) (map[string]any, string, error) {
|
||||||
|
interval := universalDurationConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
|
||||||
|
timeout := universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
|
||||||
|
deadline := time.NewTimer(timeout)
|
||||||
|
defer deadline.Stop()
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastResult map[string]any
|
||||||
|
for {
|
||||||
|
pollStarted := time.Now()
|
||||||
|
result, pollRequestID, err := c.universalPoll(ctx, executor, request, modelType, upstreamTaskID, payload)
|
||||||
|
pollFinished := time.Now()
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
|
||||||
|
}
|
||||||
|
lastResult = result
|
||||||
|
requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
|
||||||
|
if isUniversalSuccess(result) {
|
||||||
|
return result, requestID, nil
|
||||||
|
}
|
||||||
|
if isUniversalFailure(result) {
|
||||||
|
return nil, "", universalFailureError(result, requestID, startedAt)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, "", &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
||||||
|
case <-deadline.C:
|
||||||
|
return nil, "", &ClientError{Code: "timeout", Message: fmt.Sprintf("universal task %s did not finish before timeout; last status: %s", upstreamTaskID, universalStatus(lastResult)), RequestID: requestID, Retryable: true}
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c UniversalClient) universalPoll(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
|
||||||
|
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customPollScript", "custom_poll_script"); scriptText != "" {
|
||||||
|
scriptContext := universalScriptContext(request, modelType, payload)
|
||||||
|
out, err := executor.Execute(ctx, scriptengine.Options{
|
||||||
|
Script: scriptText,
|
||||||
|
Args: []any{upstreamTaskID, scriptContext},
|
||||||
|
ContextData: scriptContext,
|
||||||
|
ScriptName: "custom_poll_script:" + modelType,
|
||||||
|
PreferredEntryNames: []string{"pollTask", "poll", "main", "handler"},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", universalScriptError(err)
|
||||||
|
}
|
||||||
|
result, ok := out.(map[string]any)
|
||||||
|
if !ok || result == nil {
|
||||||
|
return nil, "", &ClientError{Code: "invalid_response", Message: "custom poll script must return an object", Retryable: false}
|
||||||
|
}
|
||||||
|
return result, requestIDFromResult(result), nil
|
||||||
|
}
|
||||||
|
pollURL := resolveUniversalTaskURL(request.Candidate.PlatformConfig, upstreamTaskID)
|
||||||
|
if pollURL == "" {
|
||||||
|
return nil, "", &ClientError{Code: "missing_configuration", Message: "universal getTaskURL is required", Retryable: false}
|
||||||
|
}
|
||||||
|
return universalGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), pollURL, request.Candidate.Credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalScriptContext(request Request, modelType string, payload map[string]any) map[string]any {
|
||||||
|
baseURL := strings.TrimRight(strings.TrimSpace(request.Candidate.BaseURL), "/")
|
||||||
|
getTaskURL := universalConfigString(request.Candidate.PlatformConfig, "getTaskURL", "get_task_url")
|
||||||
|
context := map[string]any{
|
||||||
|
"__easyaiScriptContext": true,
|
||||||
|
"baseURL": baseURL,
|
||||||
|
"getTaskURL": getTaskURL,
|
||||||
|
"authValues": cloneMapAny(request.Candidate.Credentials),
|
||||||
|
"headers": map[string]any{},
|
||||||
|
"payload": cloneMapAny(payload),
|
||||||
|
"type": modelType,
|
||||||
|
"options": map[string]any{
|
||||||
|
"task_id": request.RemoteTaskID,
|
||||||
|
"upstream_task_id": request.RemoteTaskID,
|
||||||
|
"model": request.Model,
|
||||||
|
"providerModelName": request.Candidate.ProviderModelName,
|
||||||
|
"platformId": request.Candidate.PlatformID,
|
||||||
|
"platformModelId": request.Candidate.PlatformModelID,
|
||||||
|
"canonicalModelKey": request.Candidate.CanonicalModelKey,
|
||||||
|
"modelType": modelType,
|
||||||
|
"timeout": universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms").Milliseconds(),
|
||||||
|
},
|
||||||
|
"env": cloneMapAny(request.Candidate.PlatformConfig),
|
||||||
|
"candidate": universalCandidateSnapshot(request),
|
||||||
|
}
|
||||||
|
context["createRequestURL"] = func(path string, base ...string) string {
|
||||||
|
selectedBase := baseURL
|
||||||
|
if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
|
||||||
|
selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return selectedBase + "/" + strings.TrimLeft(path, "/")
|
||||||
|
}
|
||||||
|
context["creatRequestURL"] = context["createRequestURL"]
|
||||||
|
context["resolveGetTaskURL"] = func(taskID string) string {
|
||||||
|
return resolveUniversalTaskURL(request.Candidate.PlatformConfig, taskID)
|
||||||
|
}
|
||||||
|
return context
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalCandidateSnapshot(request Request) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"modelName": request.Candidate.ModelName,
|
||||||
|
"modelAlias": request.Candidate.ModelAlias,
|
||||||
|
"providerModelName": request.Candidate.ProviderModelName,
|
||||||
|
"provider": request.Candidate.Provider,
|
||||||
|
"platformId": request.Candidate.PlatformID,
|
||||||
|
"platformModelId": request.Candidate.PlatformModelID,
|
||||||
|
"capabilities": cloneMapAny(request.Candidate.Capabilities),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalDefaultPayload(request Request) map[string]any {
|
||||||
|
body := cloneBody(request.Body)
|
||||||
|
body["model"] = upstreamModelName(request.Candidate)
|
||||||
|
if request.Kind == "images.generations" {
|
||||||
|
if n := firstPresent(body["n"], body["numImages"]); n != nil {
|
||||||
|
body["numImages"] = n
|
||||||
|
}
|
||||||
|
if aspectRatio := strings.TrimSpace(stringFromAny(body["aspect_ratio"])); aspectRatio != "" {
|
||||||
|
body["aspectRatio"] = aspectRatio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalSubmitEndpoint(request Request) string {
|
||||||
|
if endpoint := universalConfigString(request.Candidate.PlatformConfig, "submitPath", "submit_path"); endpoint != "" {
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
switch request.Kind {
|
||||||
|
case "images.generations":
|
||||||
|
return "/images/generations"
|
||||||
|
case "images.edits":
|
||||||
|
return "/images/edits"
|
||||||
|
case "videos.generations":
|
||||||
|
return "/video/generations"
|
||||||
|
default:
|
||||||
|
return "/" + strings.ReplaceAll(request.Kind, ".", "/")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalPostJSON(ctx context.Context, client *http.Client, baseURL string, endpoint string, body map[string]any, credentials map[string]any) (map[string]any, string, error) {
|
||||||
|
raw, _ := json.Marshal(body)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, providerURL(baseURL, endpoint), bytes.NewReader(raw))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
|
result, err := decodeHTTPResponse(resp)
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any) (map[string]any, string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
|
}
|
||||||
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
|
result, err := decodeHTTPResponse(resp)
|
||||||
|
return result, requestID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeUniversalResult(request Request, result map[string]any, upstreamTaskID string) map[string]any {
|
||||||
|
out := cloneMapAny(result)
|
||||||
|
if out["created"] == nil {
|
||||||
|
out["created"] = time.Now().UnixMilli()
|
||||||
|
}
|
||||||
|
if out["task_id"] == nil {
|
||||||
|
out["task_id"] = upstreamTaskID
|
||||||
|
}
|
||||||
|
if out["upstream_task_id"] == nil {
|
||||||
|
out["upstream_task_id"] = upstreamTaskID
|
||||||
|
}
|
||||||
|
if out["model"] == nil {
|
||||||
|
out["model"] = request.Model
|
||||||
|
}
|
||||||
|
if out["status"] == nil {
|
||||||
|
out["status"] = "success"
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalScriptError(err error) error {
|
||||||
|
var scriptErr *scriptengine.Error
|
||||||
|
if strings.TrimSpace(err.Error()) == "" {
|
||||||
|
return &ClientError{Code: "script_error", Message: "script execution failed", Retryable: false}
|
||||||
|
}
|
||||||
|
if errors.As(err, &scriptErr) {
|
||||||
|
return &ClientError{Code: scriptErr.ErrorCode(), Message: scriptErr.Error(), Retryable: scriptErr.ErrorCode() == "script_timeout"}
|
||||||
|
}
|
||||||
|
return &ClientError{Code: "script_error", Message: err.Error(), Retryable: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalFailureError(result map[string]any, requestID string, startedAt time.Time) error {
|
||||||
|
message := firstNonEmptyString(result["message"], result["error"], result["error_message"], "universal task failed")
|
||||||
|
return &ClientError{
|
||||||
|
Code: firstNonEmptyString(result["code"], result["error_code"], "provider_failed"),
|
||||||
|
Message: message,
|
||||||
|
RequestID: requestID,
|
||||||
|
ResponseStartedAt: startedAt,
|
||||||
|
ResponseFinishedAt: time.Now(),
|
||||||
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||||
|
Retryable: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUniversalSuccess(result map[string]any) bool {
|
||||||
|
switch universalStatus(result) {
|
||||||
|
case "success", "succeeded", "completed", "complete", "done":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUniversalFailure(result map[string]any) bool {
|
||||||
|
switch universalStatus(result) {
|
||||||
|
case "failed", "failure", "error", "cancelled", "canceled":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalStatus(result map[string]any) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(firstNonEmptyString(result["status"], result["state"], result["task_status"])))
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalTaskID(result map[string]any) string {
|
||||||
|
return firstNonEmptyString(result["upstream_task_id"], result["task_id"], result["taskId"], result["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalProgress(request Request, upstreamTaskID string) []Progress {
|
||||||
|
progress := providerProgress(request)
|
||||||
|
progress = append(progress, Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}})
|
||||||
|
return progress
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalStripPrivatePayload(payload map[string]any) map[string]any {
|
||||||
|
out := cloneMapAny(payload)
|
||||||
|
for _, key := range []string{"_originalParams", "_resolution", "_duration"} {
|
||||||
|
delete(out, key)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalSceneScript(config map[string]any, modelType string, keys ...string) string {
|
||||||
|
for _, key := range keys {
|
||||||
|
value := config[key]
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(typed) != "" {
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
if script := firstNonEmptyString(typed[modelType], typed["common"]); script != "" {
|
||||||
|
return script
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalConfigString(config map[string]any, keys ...string) string {
|
||||||
|
for _, key := range keys {
|
||||||
|
if value := strings.TrimSpace(fmt.Sprint(config[key])); value != "" && value != "<nil>" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func universalDurationConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
|
||||||
|
for _, key := range keys {
|
||||||
|
switch value := config[key].(type) {
|
||||||
|
case int:
|
||||||
|
if value > 0 {
|
||||||
|
return time.Duration(value) * time.Millisecond
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if value > 0 {
|
||||||
|
return time.Duration(value) * time.Millisecond
|
||||||
|
}
|
||||||
|
case float64:
|
||||||
|
if value > 0 {
|
||||||
|
return time.Duration(value) * time.Millisecond
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveUniversalTaskURL(config map[string]any, upstreamTaskID string) string {
|
||||||
|
template := universalConfigString(config, "getTaskURL", "get_task_url")
|
||||||
|
out := strings.TrimSpace(template)
|
||||||
|
replacements := [][2]string{
|
||||||
|
{"${upstream_task_id}", upstreamTaskID},
|
||||||
|
{"{{upstream_task_id}}", upstreamTaskID},
|
||||||
|
{"{upstream_task_id}", upstreamTaskID},
|
||||||
|
{"${task_id}", upstreamTaskID},
|
||||||
|
{"{{task_id}}", upstreamTaskID},
|
||||||
|
{"{task_id}", upstreamTaskID},
|
||||||
|
{"${taskId}", upstreamTaskID},
|
||||||
|
{"${taskID}", upstreamTaskID},
|
||||||
|
{"{{taskId}}", upstreamTaskID},
|
||||||
|
{"{{taskID}}", upstreamTaskID},
|
||||||
|
{"{taskId}", upstreamTaskID},
|
||||||
|
{"{taskID}", upstreamTaskID},
|
||||||
|
}
|
||||||
|
for _, replacement := range replacements {
|
||||||
|
out = strings.ReplaceAll(out, replacement[0], replacement[1])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
132
apps/api/internal/clients/universal_test.go
Normal file
132
apps/api/internal/clients/universal_test.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUniversalClientRunsCustomScripts(t *testing.T) {
|
||||||
|
request := Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "custom-video",
|
||||||
|
Body: map[string]any{"model": "custom-video", "prompt": "hello"},
|
||||||
|
Candidate: testUniversalCandidate(map[string]any{
|
||||||
|
"customGetParamsScript": map[string]any{
|
||||||
|
"video_generate": `async function getGenerateParams(params, context) {
|
||||||
|
return { prompt: params.prompt + "-payload", model: context.candidate.providerModelName };
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
"customSubmitScript": map[string]any{
|
||||||
|
"video_generate": `async function submitTask(payload) {
|
||||||
|
return { status: "submitted", task_id: "task-" + payload.prompt };
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
"customPollScript": map[string]any{
|
||||||
|
"video_generate": `async function pollTask(taskId) {
|
||||||
|
return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/video.mp4" }] };
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
"pollIntervalMs": 1,
|
||||||
|
"pollTimeoutMs": 1000,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
var submitted string
|
||||||
|
request.OnRemoteTaskSubmitted = func(remoteTaskID string, payload map[string]any) error {
|
||||||
|
submitted = remoteTaskID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := (UniversalClient{}).Run(context.Background(), request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run failed: %v", err)
|
||||||
|
}
|
||||||
|
if submitted != "task-hello-payload" {
|
||||||
|
t.Fatalf("unexpected remote task id: %q", submitted)
|
||||||
|
}
|
||||||
|
if response.Result["upstream_task_id"] != "task-hello-payload" {
|
||||||
|
t.Fatalf("unexpected result: %#v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniversalClientDefaultSubmitAndPoll(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||||
|
t.Fatalf("missing authorization header")
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/video/generations":
|
||||||
|
_, _ = w.Write([]byte(`{"status":"submitted","task_id":"remote-1"}`))
|
||||||
|
case "/tasks/remote-1":
|
||||||
|
_, _ = w.Write([]byte(`{"status":"success","data":[{"url":"https://cdn.example/default.mp4"}]}`))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
request := Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "default-video",
|
||||||
|
Body: map[string]any{"model": "default-video", "prompt": "hello"},
|
||||||
|
Candidate: testUniversalCandidate(map[string]any{
|
||||||
|
"getTaskURL": server.URL + "/tasks/{{task_id}}",
|
||||||
|
"pollIntervalMs": 1,
|
||||||
|
"pollTimeoutMs": 1000,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
request.Candidate.BaseURL = server.URL
|
||||||
|
|
||||||
|
response, err := (UniversalClient{}).Run(context.Background(), request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run failed: %v", err)
|
||||||
|
}
|
||||||
|
if response.Result["upstream_task_id"] != "remote-1" {
|
||||||
|
t.Fatalf("unexpected result: %#v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniversalClientResumeSkipsSubmit(t *testing.T) {
|
||||||
|
request := Request{
|
||||||
|
Kind: "videos.generations",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Model: "resume-video",
|
||||||
|
Body: map[string]any{"model": "resume-video", "prompt": "hello"},
|
||||||
|
RemoteTaskID: "existing-1",
|
||||||
|
RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
|
||||||
|
Candidate: testUniversalCandidate(map[string]any{
|
||||||
|
"customSubmitScript": `async function submitTask() { throw new Error("submit should not run"); }`,
|
||||||
|
"customPollScript": `async function pollTask(taskId) { return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/resume.mp4" }] }; }`,
|
||||||
|
"pollIntervalMs": 1,
|
||||||
|
"pollTimeoutMs": 1000,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := (UniversalClient{}).Run(context.Background(), request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run failed: %v", err)
|
||||||
|
}
|
||||||
|
if response.Result["upstream_task_id"] != "existing-1" {
|
||||||
|
t.Fatalf("unexpected result: %#v", response.Result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUniversalCandidate(config map[string]any) store.RuntimeModelCandidate {
|
||||||
|
return store.RuntimeModelCandidate{
|
||||||
|
Provider: "universal",
|
||||||
|
SpecType: "universal",
|
||||||
|
BaseURL: "https://provider.example",
|
||||||
|
Credentials: map[string]any{"apiKey": "test-key"},
|
||||||
|
PlatformConfig: config,
|
||||||
|
ModelName: "alias-model",
|
||||||
|
ProviderModelName: "provider-model",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
ClientID: "universal:test",
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -45,11 +45,11 @@ func (c VolcesClient) runImage(ctx context.Context, request Request, apiKey stri
|
|||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
|
||||||
|
responseStartedAt := time.Now()
|
||||||
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
resp, err := httpClient(request.HTTPClient, c.HTTPClient).Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||||
}
|
}
|
||||||
responseStartedAt := time.Now()
|
|
||||||
requestID := requestIDFromHTTPResponse(resp)
|
requestID := requestIDFromHTTPResponse(resp)
|
||||||
result, err := decodeHTTPResponse(resp)
|
result, err := decodeHTTPResponse(resp)
|
||||||
responseFinishedAt := time.Now()
|
responseFinishedAt := time.Now()
|
||||||
|
|||||||
150
apps/api/internal/httpapi/chat_completions_mode_test.go
Normal file
150
apps/api/internal/httpapi/chat_completions_mode_test.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package httpapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPlanTaskResponseTreatsAPIV1ChatCompletionsAsSynchronousCompatibleResponse(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("X-Async", "true")
|
||||||
|
|
||||||
|
plan := planTaskResponse("chat.completions", false, map[string]any{"stream": true}, req)
|
||||||
|
|
||||||
|
if plan.asyncMode {
|
||||||
|
t.Fatal("/api/v1/chat/completions must not enter async task mode")
|
||||||
|
}
|
||||||
|
if !plan.compatibleMode {
|
||||||
|
t.Fatal("/api/v1/chat/completions should return OpenAI-compatible response payloads")
|
||||||
|
}
|
||||||
|
if !plan.streamMode {
|
||||||
|
t.Fatal("stream=true should select SSE streaming mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlanTaskResponseKeepsAsyncTaskModeForOtherAPIV1Tasks(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/images/generations", nil)
|
||||||
|
req.Header.Set("X-Async", "true")
|
||||||
|
|
||||||
|
plan := planTaskResponse("images.generations", false, map[string]any{"stream": true}, req)
|
||||||
|
|
||||||
|
if !plan.asyncMode {
|
||||||
|
t.Fatal("non-chat /api/v1 task endpoints should keep X-Async task mode")
|
||||||
|
}
|
||||||
|
if plan.compatibleMode {
|
||||||
|
t.Fatal("non-compatible /api/v1 task endpoints should not return OpenAI-compatible payloads")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCompatibleTaskResponseReturnsJSONWhenStreamIsFalse(t *testing.T) {
|
||||||
|
executor := &fakeTaskExecutor{output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion"}}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, false, false)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
|
||||||
|
}
|
||||||
|
if executor.executeCalls != 1 || executor.streamCalls != 0 {
|
||||||
|
t.Fatalf("expected non-stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
|
||||||
|
}
|
||||||
|
var body map[string]any
|
||||||
|
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("decode response body: %v body=%s", err, recorder.Body.String())
|
||||||
|
}
|
||||||
|
if body["object"] != "chat.completion" {
|
||||||
|
t.Fatalf("unexpected compatible JSON response: %+v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCompatibleTaskResponseReturnsSSEWhenStreamIsTrue(t *testing.T) {
|
||||||
|
executor := &fakeTaskExecutor{
|
||||||
|
deltas: []clients.StreamDeltaEvent{{Text: "hel"}, {Text: "lo"}},
|
||||||
|
output: map[string]any{"id": "chatcmpl-test", "object": "chat.completion", "usage": map[string]any{"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}},
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, true, true)
|
||||||
|
|
||||||
|
if executor.executeCalls != 0 || executor.streamCalls != 1 {
|
||||||
|
t.Fatalf("expected stream execute only, got execute=%d stream=%d", executor.executeCalls, executor.streamCalls)
|
||||||
|
}
|
||||||
|
if contentType := recorder.Header().Get("Content-Type"); contentType != "text/event-stream" {
|
||||||
|
t.Fatalf("Content-Type=%q want text/event-stream", contentType)
|
||||||
|
}
|
||||||
|
body := recorder.Body.String()
|
||||||
|
for _, want := range []string{`data: {`, `"role":"assistant"`, `"created":`, `"system_fingerprint":`, `"content":"hel"`, `"content":"lo"`, `"finish_reason":"stop"`, `"usage":{"completion_tokens":2,"prompt_tokens":1,"total_tokens":3}`, "data: [DONE]"} {
|
||||||
|
if !strings.Contains(body, want) {
|
||||||
|
t.Fatalf("SSE body missing %s: %s", want, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(body, "event: message") {
|
||||||
|
t.Fatalf("chat completions stream should use OpenAI data-only SSE frames: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCompatibleTaskResponseStreamsStructuredToolAndReasoningDeltas(t *testing.T) {
|
||||||
|
executor := &fakeTaskExecutor{
|
||||||
|
deltas: []clients.StreamDeltaEvent{
|
||||||
|
{Event: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion.chunk", "created": float64(1710000000), "model": "deepseek-v4", "system_fingerprint": "fp-test", "choices": []any{map[string]any{"index": float64(0), "delta": map[string]any{"reasoning_details": []any{map[string]any{"type": "reasoning.text", "text": "detail-"}, map[string]any{"type": "reasoning.summary", "summary": "summary"}, map[string]any{"type": "reasoning.encrypted", "data": "secret"}}}, "finish_reason": nil}}}},
|
||||||
|
{Event: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion.chunk", "created": float64(1710000000), "model": "deepseek-v4", "system_fingerprint": "fp-test", "choices": []any{map[string]any{"index": float64(0), "delta": map[string]any{"content": "<think>tagged</think>answer"}, "finish_reason": nil}}}},
|
||||||
|
{Event: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion.chunk", "created": float64(1710000000), "model": "deepseek-v4", "system_fingerprint": "fp-test", "choices": []any{map[string]any{"index": float64(0), "delta": map[string]any{"functionCall": map[string]any{"name": "legacy_lookup", "arguments": "{\"city\":\"Boston\"}"}}, "finish_reason": nil}}}},
|
||||||
|
{Event: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion.chunk", "created": float64(1710000000), "model": "deepseek-v4", "system_fingerprint": "fp-test", "choices": []any{map[string]any{"index": float64(0), "delta": map[string]any{"tool_calls": []any{map[string]any{"index": float64(0), "id": "call_1", "type": "function", "function": map[string]any{"name": "lookup", "arguments": "{\"q\":"}}}}, "finish_reason": nil}}}},
|
||||||
|
{Event: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion.chunk", "created": float64(1710000000), "model": "deepseek-v4", "system_fingerprint": "fp-test", "choices": []any{map[string]any{"index": float64(0), "delta": map[string]any{"tool_calls": []any{map[string]any{"index": float64(0), "function": map[string]any{"arguments": "\"weather\"}"}}}}, "finish_reason": "tool_calls"}}}},
|
||||||
|
{Event: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion.chunk", "created": float64(1710000000), "model": "deepseek-v4", "choices": []any{}, "usage": map[string]any{"prompt_tokens": float64(4), "completion_tokens": float64(5), "total_tokens": float64(9)}}},
|
||||||
|
},
|
||||||
|
output: map[string]any{"id": "chatcmpl-upstream", "object": "chat.completion", "model": "deepseek-v4"},
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/chat/completions", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
writeCompatibleTaskResponse(context.Background(), recorder, req, executor, "chat.completions", "gpt-test", store.GatewayTask{ID: "task-test"}, &auth.User{}, true, true)
|
||||||
|
|
||||||
|
body := recorder.Body.String()
|
||||||
|
roleIndex := strings.Index(body, `"role":"assistant"`)
|
||||||
|
reasoningIndex := strings.Index(body, `"reasoning_content":"detail-summary"`)
|
||||||
|
if roleIndex < 0 || reasoningIndex < 0 || roleIndex > reasoningIndex {
|
||||||
|
t.Fatalf("assistant role should be emitted before structured deltas: %s", body)
|
||||||
|
}
|
||||||
|
for _, want := range []string{`"system_fingerprint":"fp-test"`, `"created":1710000000`, `"reasoning_content":"tagged"`, `"content":"answer"`, `"tool_calls":[{"function":{"arguments":"{\"city\":\"Boston\"}","name":"legacy_lookup"}`, `"tool_calls":[{"function":{"arguments":"{\"q\":"`, `"finish_reason":"tool_calls"`, `"choices":[],"created":1710000000`, `"usage":{"completion_tokens":5,"prompt_tokens":4,"total_tokens":9}`, "data: [DONE]"} {
|
||||||
|
if !strings.Contains(body, want) {
|
||||||
|
t.Fatalf("SSE body missing %s: %s", want, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(body, "reasoning_details") || strings.Contains(body, "<think>") || strings.Contains(body, "functionCall") {
|
||||||
|
t.Fatalf("provider-specific reasoning/tool fields should be converted away: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeTaskExecutor struct {
|
||||||
|
executeCalls int
|
||||||
|
streamCalls int
|
||||||
|
deltas []clients.StreamDeltaEvent
|
||||||
|
output map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTaskExecutor) Execute(context.Context, store.GatewayTask, *auth.User) (runner.Result, error) {
|
||||||
|
f.executeCalls++
|
||||||
|
return runner.Result{Output: f.output}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTaskExecutor) ExecuteStream(_ context.Context, _ store.GatewayTask, _ *auth.User, onDelta clients.StreamDelta) (runner.Result, error) {
|
||||||
|
f.streamCalls++
|
||||||
|
for _, delta := range f.deltas {
|
||||||
|
if err := onDelta(delta); err != nil {
|
||||||
|
return runner.Result{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return runner.Result{Output: f.output}, nil
|
||||||
|
}
|
||||||
@ -129,6 +129,13 @@ func TestCoreLocalFlow(t *testing.T) {
|
|||||||
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
|
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
|
||||||
t.Fatalf("promote smoke user: %v", err)
|
t.Fatalf("promote smoke user: %v", err)
|
||||||
}
|
}
|
||||||
|
doJSON(t, server.URL, http.MethodPost, "/api/v1/auth/login", "", map[string]any{
|
||||||
|
"account": username,
|
||||||
|
"password": password,
|
||||||
|
}, http.StatusOK, &loginResponse)
|
||||||
|
if loginResponse.AccessToken == "" {
|
||||||
|
t.Fatal("admin login did not return access token")
|
||||||
|
}
|
||||||
var smokeGatewayUserID string
|
var smokeGatewayUserID string
|
||||||
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
|
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
|
||||||
t.Fatalf("read smoke gateway user id: %v", err)
|
t.Fatalf("read smoke gateway user id: %v", err)
|
||||||
@ -316,13 +323,17 @@ VALUES ($1, 5, '{"purpose":"core-flow"}'::jsonb)`, inviteCode); err != nil {
|
|||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
defaultTextModel := "openai:gpt-4o-mini"
|
defaultTextModel := "openai:gpt-4o-mini"
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
var apiV1Chat map[string]any
|
||||||
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": defaultTextModel,
|
"model": defaultTextModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "ping"}},
|
"messages": []map[string]any{{"role": "user", "content": "ping"}},
|
||||||
}, http.StatusAccepted, &taskResponse)
|
}, "default-chat-"+suffixText, http.StatusOK, &apiV1Chat, &taskResponse.Task)
|
||||||
|
if apiV1Chat["object"] != "chat.completion" {
|
||||||
|
t.Fatalf("unexpected api v1 chat response: %+v", apiV1Chat)
|
||||||
|
}
|
||||||
if taskResponse.Task.ID == "" || taskResponse.Task.Status != "succeeded" || taskResponse.Task.RunMode != "simulation" {
|
if taskResponse.Task.ID == "" || taskResponse.Task.Status != "succeeded" || taskResponse.Task.RunMode != "simulation" {
|
||||||
t.Fatalf("unexpected task response: %+v", taskResponse.Task)
|
t.Fatalf("unexpected task response: %+v", taskResponse.Task)
|
||||||
}
|
}
|
||||||
@ -513,13 +524,13 @@ LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
|
|||||||
ErrorCode string `json:"errorCode"`
|
ErrorCode string `json:"errorCode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||||
"model": deniedModel,
|
"model": deniedModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "permission deny"}},
|
"messages": []map[string]any{{"role": "user", "content": "permission deny"}},
|
||||||
}, http.StatusAccepted, &deniedTask)
|
}, "permission-deny-"+suffixText, http.StatusNotFound, nil, &deniedTask.Task)
|
||||||
if deniedTask.Task.Status != "failed" || deniedTask.Task.ErrorCode != "no_model_candidate" {
|
if deniedTask.Task.Status != "failed" || deniedTask.Task.ErrorCode != "no_model_candidate" {
|
||||||
t.Fatalf("deny access rule should hide denied model from runtime candidates: %+v", deniedTask.Task)
|
t.Fatalf("deny access rule should hide denied model from runtime candidates: %+v", deniedTask.Task)
|
||||||
}
|
}
|
||||||
@ -561,13 +572,13 @@ LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
|
|||||||
ErrorCode string `json:"errorCode"`
|
ErrorCode string `json:"errorCode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||||
"model": controlledModel,
|
"model": controlledModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "allow should block other keys"}},
|
"messages": []map[string]any{{"role": "user", "content": "allow should block other keys"}},
|
||||||
}, http.StatusAccepted, &blockedControlledTask)
|
}, "permission-allow-block-"+suffixText, http.StatusNotFound, nil, &blockedControlledTask.Task)
|
||||||
if blockedControlledTask.Task.Status != "failed" || blockedControlledTask.Task.ErrorCode != "no_model_candidate" {
|
if blockedControlledTask.Task.Status != "failed" || blockedControlledTask.Task.ErrorCode != "no_model_candidate" {
|
||||||
t.Fatalf("allow access rule should make the resource unavailable to unmatched subjects: %+v", blockedControlledTask.Task)
|
t.Fatalf("allow access rule should make the resource unavailable to unmatched subjects: %+v", blockedControlledTask.Task)
|
||||||
}
|
}
|
||||||
@ -586,13 +597,13 @@ LIMIT 1`).Scan(&gptImageModelTypesRaw); err != nil {
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", chatOnlyAPIKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, chatOnlyAPIKeyResponse.Secret, map[string]any{
|
||||||
"model": controlledModel,
|
"model": controlledModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "allow should pass"}},
|
"messages": []map[string]any{{"role": "user", "content": "allow should pass"}},
|
||||||
}, http.StatusAccepted, &allowedControlledTask)
|
}, "permission-allow-pass-"+suffixText, http.StatusOK, nil, &allowedControlledTask.Task)
|
||||||
if allowedControlledTask.Task.Status != "succeeded" {
|
if allowedControlledTask.Task.Status != "succeeded" {
|
||||||
t.Fatalf("matching allow access rule should make the controlled model usable: %+v", allowedControlledTask.Task)
|
t.Fatalf("matching allow access rule should make the controlled model usable: %+v", allowedControlledTask.Task)
|
||||||
}
|
}
|
||||||
@ -645,13 +656,13 @@ WHERE gateway_user_id = $1::uuid
|
|||||||
FinalChargeAmount float64 `json:"finalChargeAmount"`
|
FinalChargeAmount float64 `json:"finalChargeAmount"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": pricingModel,
|
"model": pricingModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "priced ping"}},
|
"messages": []map[string]any{{"role": "user", "content": "priced ping"}},
|
||||||
}, http.StatusAccepted, &pricingTask)
|
}, "pricing-chat-"+suffixText, http.StatusOK, nil, &pricingTask.Task)
|
||||||
if pricingTask.Task.Status != "succeeded" || !floatNear(pricingTask.Task.FinalChargeAmount, 0.028) {
|
if pricingTask.Task.Status != "succeeded" || !floatNear(pricingTask.Task.FinalChargeAmount, 0.028) {
|
||||||
t.Fatalf("custom pricing rule set should drive text billing, got task=%+v", pricingTask.Task)
|
t.Fatalf("custom pricing rule set should drive text billing, got task=%+v", pricingTask.Task)
|
||||||
}
|
}
|
||||||
@ -757,14 +768,14 @@ WHERE reference_type = 'gateway_task'
|
|||||||
ErrorCode string `json:"errorCode"`
|
ErrorCode string `json:"errorCode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": rateLimitedModel,
|
"model": rateLimitedModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"simulationProfile": "non_retryable_failure",
|
"simulationProfile": "non_retryable_failure",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "failed first"}},
|
"messages": []map[string]any{{"role": "user", "content": "failed first"}},
|
||||||
}, http.StatusAccepted, &rateLimitFailedTask)
|
}, "rate-limit-failed-first-"+suffixText, http.StatusBadGateway, nil, &rateLimitFailedTask.Task)
|
||||||
if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" {
|
if rateLimitFailedTask.Task.Status != "failed" || rateLimitFailedTask.Task.ErrorCode != "bad_request" {
|
||||||
t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task)
|
t.Fatalf("failed rate-limited task should fail before consuming rpm: %+v", rateLimitFailedTask.Task)
|
||||||
}
|
}
|
||||||
@ -774,13 +785,13 @@ WHERE reference_type = 'gateway_task'
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": rateLimitedModel,
|
"model": rateLimitedModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "first"}},
|
"messages": []map[string]any{{"role": "user", "content": "first"}},
|
||||||
}, http.StatusAccepted, &rateLimitTaskOne)
|
}, "rate-limit-first-"+suffixText, http.StatusOK, nil, &rateLimitTaskOne.Task)
|
||||||
if rateLimitTaskOne.Task.Status != "succeeded" {
|
if rateLimitTaskOne.Task.Status != "succeeded" {
|
||||||
t.Fatalf("first rate-limited task should succeed: %+v", rateLimitTaskOne.Task)
|
t.Fatalf("first rate-limited task should succeed: %+v", rateLimitTaskOne.Task)
|
||||||
}
|
}
|
||||||
@ -790,13 +801,13 @@ WHERE reference_type = 'gateway_task'
|
|||||||
ErrorCode string `json:"errorCode"`
|
ErrorCode string `json:"errorCode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": rateLimitedModel,
|
"model": rateLimitedModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "second"}},
|
"messages": []map[string]any{{"role": "user", "content": "second"}},
|
||||||
}, http.StatusAccepted, &rateLimitTaskTwo)
|
}, "rate-limit-second-"+suffixText, http.StatusTooManyRequests, nil, &rateLimitTaskTwo.Task)
|
||||||
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
|
if rateLimitTaskTwo.Task.Status != "failed" || rateLimitTaskTwo.Task.ErrorCode != "rate_limit" {
|
||||||
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
|
t.Fatalf("runtime policy rate limit should fail second task with rate_limit: %+v", rateLimitTaskTwo.Task)
|
||||||
}
|
}
|
||||||
@ -808,12 +819,12 @@ WHERE reference_type = 'gateway_task'
|
|||||||
AsyncMode bool `json:"asyncMode"`
|
AsyncMode bool `json:"asyncMode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/responses", apiKeyResponse.Secret, map[string]any{
|
||||||
"model": rateLimitedModel,
|
"model": rateLimitedModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "async queued"}},
|
"input": "async queued",
|
||||||
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask)
|
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &asyncRateLimitTask)
|
||||||
if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode {
|
if asyncRateLimitTask.TaskID == "" || asyncRateLimitTask.Task.ID != asyncRateLimitTask.TaskID || !asyncRateLimitTask.Task.AsyncMode {
|
||||||
t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask)
|
t.Fatalf("async task response should expose task id and async mode: %+v", asyncRateLimitTask)
|
||||||
@ -984,11 +995,11 @@ WHERE reference_type = 'gateway_task'
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": failoverModel,
|
"model": failoverModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "retry please"}},
|
"messages": []map[string]any{{"role": "user", "content": "retry please"}},
|
||||||
}, http.StatusAccepted, &failoverTask)
|
}, "failover-chat-"+suffixText, http.StatusOK, nil, &failoverTask.Task)
|
||||||
if failoverTask.Task.Status != "succeeded" {
|
if failoverTask.Task.Status != "succeeded" {
|
||||||
t.Fatalf("failover task should succeed through second client: %+v", failoverTask.Task)
|
t.Fatalf("failover task should succeed through second client: %+v", failoverTask.Task)
|
||||||
}
|
}
|
||||||
@ -1103,13 +1114,13 @@ WHERE failed.id = $1::uuid`, failedPlatform.ID, successPlatform.ID, unrelatedPri
|
|||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": degradeModel,
|
"model": degradeModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "degrade please"}},
|
"messages": []map[string]any{{"role": "user", "content": "degrade please"}},
|
||||||
}, http.StatusAccepted, °radeTask)
|
}, "degrade-chat-"+suffixText, http.StatusOK, nil, °radeTask.Task)
|
||||||
if degradeTask.Task.Status != "succeeded" {
|
if degradeTask.Task.Status != "succeeded" {
|
||||||
t.Fatalf("degrade task should fail over after cooling down failed model: %+v", degradeTask.Task)
|
t.Fatalf("degrade task should fail over after cooling down failed model: %+v", degradeTask.Task)
|
||||||
}
|
}
|
||||||
@ -1170,13 +1181,13 @@ WHERE m.platform_id = $1::uuid
|
|||||||
ErrorCode string `json:"errorCode"`
|
ErrorCode string `json:"errorCode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, server.URL, apiKeyResponse.Secret, map[string]any{
|
||||||
"model": autoDisableModel,
|
"model": autoDisableModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "disable please"}},
|
"messages": []map[string]any{{"role": "user", "content": "disable please"}},
|
||||||
}, http.StatusAccepted, &autoDisableTask)
|
}, "auto-disable-chat-"+suffixText, http.StatusBadGateway, nil, &autoDisableTask.Task)
|
||||||
if autoDisableTask.Task.Status != "failed" || autoDisableTask.Task.ErrorCode != "invalid_api_key" {
|
if autoDisableTask.Task.Status != "failed" || autoDisableTask.Task.ErrorCode != "invalid_api_key" {
|
||||||
t.Fatalf("auto disable task should fail with invalid_api_key: %+v", autoDisableTask.Task)
|
t.Fatalf("auto disable task should fail with invalid_api_key: %+v", autoDisableTask.Task)
|
||||||
}
|
}
|
||||||
@ -1293,12 +1304,12 @@ WHERE m.platform_id = $1::uuid
|
|||||||
AsyncMode bool `json:"asyncMode"`
|
AsyncMode bool `json:"asyncMode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/chat/completions", apiKeyResponse.Secret, map[string]any{
|
doJSONWithHeaders(t, server.URL, http.MethodPost, "/api/v1/responses", apiKeyResponse.Secret, map[string]any{
|
||||||
"model": defaultTextModel,
|
"model": defaultTextModel,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 2000,
|
"simulationDurationMs": 2000,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "river worker restart"}},
|
"input": "river worker restart",
|
||||||
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask)
|
}, map[string]string{"X-Async": "true"}, http.StatusAccepted, &restartAsyncTask)
|
||||||
if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode {
|
if restartAsyncTask.TaskID == "" || !restartAsyncTask.Task.AsyncMode {
|
||||||
t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask)
|
t.Fatalf("restart async task should be accepted as async: %+v", restartAsyncTask)
|
||||||
@ -1398,14 +1409,41 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
|
|||||||
t.Fatalf("connect migration db: %v", err)
|
t.Fatalf("connect migration db: %v", err)
|
||||||
}
|
}
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
if _, err := pool.Exec(ctx, `
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||||
|
version text PRIMARY KEY,
|
||||||
|
applied_at timestamptz NOT NULL DEFAULT now()
|
||||||
|
);`); err != nil {
|
||||||
|
t.Fatalf("ensure schema migrations: %v", err)
|
||||||
|
}
|
||||||
for _, migrationPath := range migrationFiles {
|
for _, migrationPath := range migrationFiles {
|
||||||
|
version := strings.TrimSuffix(filepath.Base(migrationPath), filepath.Ext(migrationPath))
|
||||||
|
var exists bool
|
||||||
|
if err := pool.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&exists); err != nil {
|
||||||
|
t.Fatalf("check migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
migration, err := os.ReadFile(migrationPath)
|
migration, err := os.ReadFile(migrationPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err)
|
t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
}
|
}
|
||||||
if _, err := pool.Exec(ctx, string(migration)); err != nil {
|
tx, err := pool.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("begin migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, string(migration)); err != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err)
|
t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
}
|
}
|
||||||
|
if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations(version) VALUES($1)", version); err != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
|
t.Fatalf("record migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
t.Fatalf("commit migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1453,6 +1491,20 @@ func doJSONWithHeaders(t *testing.T, baseURL string, method string, path string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func doAPIV1ChatCompletionAndLoadTask(t *testing.T, ctx context.Context, pool *pgxpool.Pool, baseURL string, token string, payload map[string]any, marker string, expectedStatus int, responseOut any, taskDetailOut any) string {
|
||||||
|
t.Helper()
|
||||||
|
payload["integrationTestMarker"] = marker
|
||||||
|
if responseOut == nil {
|
||||||
|
responseOut = &map[string]any{}
|
||||||
|
}
|
||||||
|
doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", token, payload, expectedStatus, responseOut)
|
||||||
|
taskID := waitForTaskIDByRequestField(t, ctx, pool, "integrationTestMarker", marker, 2*time.Second)
|
||||||
|
if taskDetailOut != nil {
|
||||||
|
doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskID, token, nil, http.StatusOK, taskDetailOut)
|
||||||
|
}
|
||||||
|
return taskID
|
||||||
|
}
|
||||||
|
|
||||||
type taskWaitDetail struct {
|
type taskWaitDetail struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
@ -1481,6 +1533,11 @@ func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func waitForTaskIDByRequestMarker(t *testing.T, ctx context.Context, pool *pgxpool.Pool, marker string, timeout time.Duration) string {
|
func waitForTaskIDByRequestMarker(t *testing.T, ctx context.Context, pool *pgxpool.Pool, marker string, timeout time.Duration) string {
|
||||||
|
t.Helper()
|
||||||
|
return waitForTaskIDByRequestField(t, ctx, pool, "cancelTestId", marker, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForTaskIDByRequestField(t *testing.T, ctx context.Context, pool *pgxpool.Pool, key string, value string, timeout time.Duration) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
deadline := time.Now().Add(timeout)
|
deadline := time.Now().Add(timeout)
|
||||||
for time.Now().Before(deadline) {
|
for time.Now().Before(deadline) {
|
||||||
@ -1488,15 +1545,15 @@ func waitForTaskIDByRequestMarker(t *testing.T, ctx context.Context, pool *pgxpo
|
|||||||
err := pool.QueryRow(ctx, `
|
err := pool.QueryRow(ctx, `
|
||||||
SELECT id::text
|
SELECT id::text
|
||||||
FROM gateway_tasks
|
FROM gateway_tasks
|
||||||
WHERE request->>'cancelTestId' = $1
|
WHERE request->>$1 = $2
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
LIMIT 1`, marker).Scan(&taskID)
|
LIMIT 1`, key, value).Scan(&taskID)
|
||||||
if err == nil && taskID != "" {
|
if err == nil && taskID != "" {
|
||||||
return taskID
|
return taskID
|
||||||
}
|
}
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
}
|
}
|
||||||
t.Fatalf("task with request marker %s was not created within %s", marker, timeout)
|
t.Fatalf("task with request %s=%s was not created within %s", key, value, timeout)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1577,13 +1634,13 @@ func assertLoadAvoidanceSimulatedRetryChain(t *testing.T, ctx context.Context, t
|
|||||||
ErrorCode string `json:"errorCode"`
|
ErrorCode string `json:"errorCode"`
|
||||||
} `json:"task"`
|
} `json:"task"`
|
||||||
}
|
}
|
||||||
doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", runtimeToken, map[string]any{
|
doAPIV1ChatCompletionAndLoadTask(t, ctx, testPool, baseURL, runtimeToken, map[string]any{
|
||||||
"model": model,
|
"model": model,
|
||||||
"runMode": "simulation",
|
"runMode": "simulation",
|
||||||
"simulation": true,
|
"simulation": true,
|
||||||
"simulationDurationMs": 5,
|
"simulationDurationMs": 5,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "load avoidance retry chain"}},
|
"messages": []map[string]any{{"role": "user", "content": "load avoidance retry chain"}},
|
||||||
}, http.StatusAccepted, &taskResponse)
|
}, "load-avoidance-"+suffixText, http.StatusBadGateway, nil, &taskResponse.Task)
|
||||||
if taskResponse.Task.ID == "" || taskResponse.Task.Status != "failed" || taskResponse.Task.ErrorCode != "bad_request" {
|
if taskResponse.Task.ID == "" || taskResponse.Task.Status != "failed" || taskResponse.Task.ErrorCode != "bad_request" {
|
||||||
t.Fatalf("load avoidance task should only fail after avoided clients are retried, got %+v", taskResponse.Task)
|
t.Fatalf("load avoidance task should only fail after avoided clients are retried, got %+v", taskResponse.Task)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,6 +11,22 @@ import (
|
|||||||
|
|
||||||
const maxGatewayUploadBytes = 256 << 20
|
const maxGatewayUploadBytes = 256 << 20
|
||||||
|
|
||||||
|
// uploadFile godoc
|
||||||
|
// @Summary 上传文件
|
||||||
|
// @Description 上传文件到配置的文件存储通道;没有启用通道时回退到本地静态上传目录。单文件最大 256MiB。
|
||||||
|
// @Tags files
|
||||||
|
// @Accept multipart/form-data
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param file formData file true "要上传的文件"
|
||||||
|
// @Param source formData string false "上传来源标识" default(ai-gateway-openapi)
|
||||||
|
// @Success 200 {object} FileUploadResponse
|
||||||
|
// @Failure 400 {object} ErrorEnvelope
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 502 {object} ErrorEnvelope
|
||||||
|
// @Failure 503 {object} ErrorEnvelope
|
||||||
|
// @Router /api/v1/files/upload [post]
|
||||||
|
// @Router /v1/files/upload [post]
|
||||||
func (s *Server) uploadFile(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) uploadFile(w http.ResponseWriter, r *http.Request) {
|
||||||
r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
|
r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
|
||||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/netproxy"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/netproxy"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -804,7 +805,7 @@ func (s *Server) estimatePricing(w http.ResponseWriter, r *http.Request) {
|
|||||||
estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user)
|
estimate, err := s.runner.Estimate(r.Context(), kind, model, body, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, store.ErrNoModelCandidate) {
|
if errors.Is(err, store.ErrNoModelCandidate) {
|
||||||
writeError(w, statusFromRunError(err), err.Error(), store.ModelCandidateErrorCode(err))
|
writeErrorWithDetails(w, statusFromRunError(err), runErrorMessage(err), runErrorDetails(err), store.ModelCandidateErrorCode(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Error("estimate pricing failed", "error", err)
|
s.logger.Error("estimate pricing failed", "error", err)
|
||||||
@ -858,7 +859,7 @@ func (s *Server) listModelRateLimitStatuses(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// createTask godoc
|
// createTask godoc
|
||||||
// @Summary 创建或执行 AI 任务
|
// @Summary 创建或执行 AI 任务
|
||||||
// @Description 网关任务接口按 model 选择平台模型;/api/v1 路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。
|
// @Description 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。
|
||||||
// @Tags tasks
|
// @Tags tasks
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
@ -874,7 +875,6 @@ func (s *Server) listModelRateLimitStatuses(w http.ResponseWriter, r *http.Reque
|
|||||||
// @Failure 404 {object} ErrorEnvelope
|
// @Failure 404 {object} ErrorEnvelope
|
||||||
// @Failure 429 {object} ErrorEnvelope
|
// @Failure 429 {object} ErrorEnvelope
|
||||||
// @Failure 502 {object} ErrorEnvelope
|
// @Failure 502 {object} ErrorEnvelope
|
||||||
// @Router /api/v1/chat/completions [post]
|
|
||||||
// @Router /api/v1/responses [post]
|
// @Router /api/v1/responses [post]
|
||||||
// @Router /api/v1/images/generations [post]
|
// @Router /api/v1/images/generations [post]
|
||||||
// @Router /api/v1/images/edits [post]
|
// @Router /api/v1/images/edits [post]
|
||||||
@ -909,13 +909,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
|||||||
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
writeError(w, http.StatusForbidden, "api key scope does not allow this capability")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
asyncMode := asyncRequest(r)
|
responsePlan := planTaskResponse(kind, compatible, body, r)
|
||||||
|
|
||||||
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
task, err := s.store.CreateTask(r.Context(), store.CreateTaskInput{
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Model: model,
|
Model: model,
|
||||||
RunMode: runModeFromRequest(body),
|
RunMode: runModeFromRequest(body),
|
||||||
Async: asyncMode,
|
Async: responsePlan.asyncMode,
|
||||||
Request: body,
|
Request: body,
|
||||||
}, user)
|
}, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -923,7 +923,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
|||||||
writeError(w, http.StatusInternalServerError, "create task failed")
|
writeError(w, http.StatusInternalServerError, "create task failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if asyncMode {
|
if responsePlan.asyncMode {
|
||||||
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
|
if err := s.runner.EnqueueAsyncTask(r.Context(), task); err != nil {
|
||||||
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
|
writeError(w, http.StatusInternalServerError, err.Error(), "enqueue_failed")
|
||||||
return
|
return
|
||||||
@ -933,14 +933,84 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
|||||||
}
|
}
|
||||||
runCtx, cancelRun := s.requestExecutionContext(r)
|
runCtx, cancelRun := s.requestExecutionContext(r)
|
||||||
defer cancelRun()
|
defer cancelRun()
|
||||||
if compatible {
|
if responsePlan.compatibleMode {
|
||||||
if boolValue(body, "stream") {
|
writeCompatibleTaskResponse(runCtx, w, r, s.runner, kind, model, task, user, responsePlan.streamMode, streamIncludeUsage(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, runErr := s.runner.Execute(runCtx, task, user)
|
||||||
|
if runErr != nil {
|
||||||
|
s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !requestStillConnected(r) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeTaskAccepted(w, result.Task)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAPIV1ChatCompletions godoc
|
||||||
|
// @Summary 创建 Chat Completions
|
||||||
|
// @Description /api/v1/chat/completions 同步执行:stream=true 返回 text/event-stream SSE;stream=false 或未传返回兼容 JSON;该接口忽略 X-Async。
|
||||||
|
// @Tags tasks
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Produce text/event-stream
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param X-Async header bool false "该接口忽略此参数"
|
||||||
|
// @Param input body TaskRequest true "Chat Completions 请求"
|
||||||
|
// @Success 200 {object} ChatCompletionCompatibleResponse
|
||||||
|
// @Failure 400 {object} ErrorEnvelope
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 402 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 404 {object} ErrorEnvelope
|
||||||
|
// @Failure 429 {object} ErrorEnvelope
|
||||||
|
// @Failure 502 {object} ErrorEnvelope
|
||||||
|
// @Router /api/v1/chat/completions [post]
|
||||||
|
func (s *Server) createAPIV1ChatCompletions() http.Handler {
|
||||||
|
return s.createTask("chat.completions", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) requestExecutionContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||||
|
base := context.WithoutCancel(r.Context())
|
||||||
|
if s.ctx == nil {
|
||||||
|
return base, func() {}
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(base)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
cancel()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return ctx, cancel
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestStillConnected(r *http.Request) bool {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type taskExecutor interface {
|
||||||
|
Execute(context.Context, store.GatewayTask, *auth.User) (runner.Result, error)
|
||||||
|
ExecuteStream(context.Context, store.GatewayTask, *auth.User, clients.StreamDelta) (runner.Result, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeCompatibleTaskResponse(runCtx context.Context, w http.ResponseWriter, r *http.Request, executor taskExecutor, kind string, model string, task store.GatewayTask, user *auth.User, streamMode bool, includeUsage bool) {
|
||||||
|
if streamMode {
|
||||||
flusher := prepareCompatibleStream(w)
|
flusher := prepareCompatibleStream(w)
|
||||||
result, runErr := s.runner.ExecuteStream(runCtx, task, user, func(delta string) error {
|
streamWriter := newCompatibleStreamWriter(kind, model, includeUsage)
|
||||||
|
result, runErr := executor.ExecuteStream(runCtx, task, user, func(delta clients.StreamDeltaEvent) error {
|
||||||
if !requestStillConnected(r) {
|
if !requestStillConnected(r) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
writeCompatibleDelta(w, kind, model, delta)
|
streamWriter.writeDelta(w, delta)
|
||||||
if flusher != nil {
|
if flusher != nil {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
@ -974,13 +1044,14 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
|||||||
if !requestStillConnected(r) {
|
if !requestStillConnected(r) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeCompatibleDone(w, kind, model, result.Output)
|
streamWriter.writeDone(w, result.Output)
|
||||||
if flusher != nil {
|
if flusher != nil {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
result, runErr := s.runner.Execute(runCtx, task, user)
|
|
||||||
|
result, runErr := executor.Execute(runCtx, task, user)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
if !requestStillConnected(r) {
|
if !requestStillConnected(r) {
|
||||||
return
|
return
|
||||||
@ -992,43 +1063,12 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSON(w, http.StatusOK, result.Output)
|
writeJSON(w, http.StatusOK, result.Output)
|
||||||
return
|
|
||||||
}
|
|
||||||
result, runErr := s.runner.Execute(runCtx, task, user)
|
|
||||||
if runErr != nil {
|
|
||||||
s.logger.Warn("task completed with failure", "kind", kind, "taskId", task.ID, "error", runErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !requestStillConnected(r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeTaskAccepted(w, result.Task)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) requestExecutionContext(r *http.Request) (context.Context, context.CancelFunc) {
|
func streamIncludeUsage(body map[string]any) bool {
|
||||||
base := context.WithoutCancel(r.Context())
|
streamOptions, _ := body["stream_options"].(map[string]any)
|
||||||
if s.ctx == nil {
|
includeUsage, _ := streamOptions["include_usage"].(bool)
|
||||||
return base, func() {}
|
return includeUsage
|
||||||
}
|
|
||||||
ctx, cancel := context.WithCancel(base)
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
cancel()
|
|
||||||
case <-ctx.Done():
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return ctx, cancel
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestStillConnected(r *http.Request) bool {
|
|
||||||
select {
|
|
||||||
case <-r.Context().Done():
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func asyncRequest(r *http.Request) bool {
|
func asyncRequest(r *http.Request) bool {
|
||||||
@ -1036,6 +1076,26 @@ func asyncRequest(r *http.Request) bool {
|
|||||||
return value == "1" || value == "true" || value == "yes" || value == "on"
|
return value == "1" || value == "true" || value == "yes" || value == "on"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type taskResponsePlan struct {
|
||||||
|
asyncMode bool
|
||||||
|
compatibleMode bool
|
||||||
|
streamMode bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func planTaskResponse(kind string, compatible bool, body map[string]any, r *http.Request) taskResponsePlan {
|
||||||
|
asyncMode := asyncRequest(r)
|
||||||
|
compatibleMode := compatible
|
||||||
|
if kind == "chat.completions" && !compatible {
|
||||||
|
asyncMode = false
|
||||||
|
compatibleMode = true
|
||||||
|
}
|
||||||
|
return taskResponsePlan{
|
||||||
|
asyncMode: asyncMode,
|
||||||
|
compatibleMode: compatibleMode,
|
||||||
|
streamMode: boolValue(body, "stream"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
|
func writeTaskAccepted(w http.ResponseWriter, task store.GatewayTask) {
|
||||||
writeJSON(w, http.StatusAccepted, map[string]any{
|
writeJSON(w, http.StatusAccepted, map[string]any{
|
||||||
"taskId": task.ID,
|
"taskId": task.ID,
|
||||||
@ -1115,6 +1175,9 @@ func runErrorDetails(err error) map[string]any {
|
|||||||
if detail := rateLimitErrorDetail(err); len(detail) > 0 {
|
if detail := rateLimitErrorDetail(err); len(detail) > 0 {
|
||||||
return map[string]any{"rateLimit": detail}
|
return map[string]any{"rateLimit": detail}
|
||||||
}
|
}
|
||||||
|
if detail := store.ModelCandidateErrorDetails(err); len(detail) > 0 {
|
||||||
|
return map[string]any{"modelCandidate": detail}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -123,6 +123,19 @@ type TaskEventListResponse struct {
|
|||||||
Items []store.TaskEvent `json:"items"`
|
Items []store.TaskEvent `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FileStorageChannelListResponse struct {
|
||||||
|
Items []store.FileStorageChannel `json:"items"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileUploadResponse struct {
|
||||||
|
ID string `json:"id,omitempty" example:"file_abc123"`
|
||||||
|
URL string `json:"url,omitempty" example:"/static/uploaded/upload-abc123.png"`
|
||||||
|
Filename string `json:"filename,omitempty" example:"image.png"`
|
||||||
|
ContentType string `json:"contentType,omitempty" example:"image/png"`
|
||||||
|
Size int `json:"size,omitempty" example:"1024"`
|
||||||
|
AssetStorage map[string]interface{} `json:"assetStorage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type ReplacePlatformModelsRequest struct {
|
type ReplacePlatformModelsRequest struct {
|
||||||
Models []store.CreatePlatformModelInput `json:"models"`
|
Models []store.CreatePlatformModelInput `json:"models"`
|
||||||
}
|
}
|
||||||
@ -166,6 +179,8 @@ type TaskRequest struct {
|
|||||||
Stream bool `json:"stream,omitempty" example:"false"`
|
Stream bool `json:"stream,omitempty" example:"false"`
|
||||||
RunMode string `json:"runMode,omitempty" example:"simulation"`
|
RunMode string `json:"runMode,omitempty" example:"simulation"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
|
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
|
||||||
|
// ReasoningEffort 推理深度,OpenAI-compatible 请求字段;开放字符串,取值随 provider 和模型能力而定,常见值为 none、minimal、low、medium、high、xhigh,也可配置 max 等供应商自定义值。
|
||||||
|
ReasoningEffort string `json:"reasoning_effort,omitempty" example:"medium"`
|
||||||
Size string `json:"size,omitempty" example:"1024x1024"`
|
Size string `json:"size,omitempty" example:"1024x1024"`
|
||||||
Duration int `json:"duration,omitempty" example:"5"`
|
Duration int `json:"duration,omitempty" example:"5"`
|
||||||
Resolution string `json:"resolution,omitempty" example:"720p"`
|
Resolution string `json:"resolution,omitempty" example:"720p"`
|
||||||
@ -176,6 +191,8 @@ type ChatCompletionRequest struct {
|
|||||||
Messages []ChatMessage `json:"messages"`
|
Messages []ChatMessage `json:"messages"`
|
||||||
Temperature float64 `json:"temperature,omitempty" example:"0.7"`
|
Temperature float64 `json:"temperature,omitempty" example:"0.7"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
|
MaxTokens int `json:"max_tokens,omitempty" example:"512"`
|
||||||
|
// ReasoningEffort 推理深度,OpenAI-compatible 请求字段;开放字符串,取值随 provider 和模型能力而定,常见值为 none、minimal、low、medium、high、xhigh,也可配置 max 等供应商自定义值。
|
||||||
|
ReasoningEffort string `json:"reasoning_effort,omitempty" example:"medium"`
|
||||||
Stream bool `json:"stream,omitempty" example:"false"`
|
Stream bool `json:"stream,omitempty" example:"false"`
|
||||||
RunMode string `json:"runMode,omitempty" example:"simulation"`
|
RunMode string `json:"runMode,omitempty" example:"simulation"`
|
||||||
}
|
}
|
||||||
@ -229,6 +246,32 @@ type CompatibleResponse struct {
|
|||||||
Usage map[string]interface{} `json:"usage,omitempty"`
|
Usage map[string]interface{} `json:"usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ChatCompletionCompatibleResponse struct {
|
||||||
|
ID string `json:"id" example:"chatcmpl-123"`
|
||||||
|
Object string `json:"object" example:"chat.completion"`
|
||||||
|
Created int64 `json:"created,omitempty" example:"1710000000"`
|
||||||
|
Model string `json:"model" example:"gpt-4o-mini"`
|
||||||
|
Choices []ChatCompletionChoice `json:"choices"`
|
||||||
|
Usage *ChatCompletionUsage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoice struct {
|
||||||
|
Index int `json:"index" example:"0"`
|
||||||
|
Message ChatCompletionChoiceMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty" example:"stop"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoiceMessage struct {
|
||||||
|
Role string `json:"role" example:"assistant"`
|
||||||
|
Content string `json:"content" example:"Hello"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens,omitempty" example:"12"`
|
||||||
|
CompletionTokens int `json:"completion_tokens,omitempty" example:"8"`
|
||||||
|
TotalTokens int `json:"total_tokens,omitempty" example:"20"`
|
||||||
|
}
|
||||||
|
|
||||||
type NetworkProxyConfigResponse struct {
|
type NetworkProxyConfigResponse struct {
|
||||||
GlobalHTTPProxy string `json:"globalHttpProxy" example:"http://127.0.0.1:7890"`
|
GlobalHTTPProxy string `json:"globalHttpProxy" example:"http://127.0.0.1:7890"`
|
||||||
GlobalHTTPProxySet bool `json:"globalHttpProxySet" example:"true"`
|
GlobalHTTPProxySet bool `json:"globalHttpProxySet" example:"true"`
|
||||||
|
|||||||
@ -126,7 +126,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
|||||||
mux.Handle("GET /api/v1/playground/models", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listPlayableModels)))
|
mux.Handle("GET /api/v1/playground/models", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listPlayableModels)))
|
||||||
mux.Handle("GET /api/admin/runtime/rate-limit-windows", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listRateLimitWindows)))
|
mux.Handle("GET /api/admin/runtime/rate-limit-windows", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listRateLimitWindows)))
|
||||||
mux.Handle("GET /api/admin/runtime/model-rate-limits", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listModelRateLimitStatuses)))
|
mux.Handle("GET /api/admin/runtime/model-rate-limits", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listModelRateLimitStatuses)))
|
||||||
mux.Handle("POST /api/v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", false)))
|
mux.Handle("POST /api/v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createAPIV1ChatCompletions()))
|
||||||
mux.Handle("POST /api/v1/responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", false)))
|
mux.Handle("POST /api/v1/responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", false)))
|
||||||
mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false)))
|
mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false)))
|
||||||
mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false)))
|
mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false)))
|
||||||
|
|||||||
@ -9,10 +9,28 @@ import (
|
|||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// serveGeneratedStaticAsset godoc
|
||||||
|
// @Summary 获取本地生成资源
|
||||||
|
// @Description 从本地生成资源目录读取图片、视频等任务产物;不存在时返回 404。
|
||||||
|
// @Tags static
|
||||||
|
// @Produce octet-stream
|
||||||
|
// @Param asset path string true "资源文件名"
|
||||||
|
// @Success 200 {file} file
|
||||||
|
// @Failure 404 {string} string "Not Found"
|
||||||
|
// @Router /static/generated/{asset} [get]
|
||||||
func (s *Server) serveGeneratedStaticAsset(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) serveGeneratedStaticAsset(w http.ResponseWriter, r *http.Request) {
|
||||||
s.serveLocalStaticAsset(w, r, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir)
|
s.serveLocalStaticAsset(w, r, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serveUploadedStaticAsset godoc
|
||||||
|
// @Summary 获取本地上传资源
|
||||||
|
// @Description 从本地上传资源目录读取用户上传文件;不存在时返回 404。
|
||||||
|
// @Tags static
|
||||||
|
// @Produce octet-stream
|
||||||
|
// @Param asset path string true "资源文件名"
|
||||||
|
// @Success 200 {file} file
|
||||||
|
// @Failure 404 {string} string "Not Found"
|
||||||
|
// @Router /static/uploaded/{asset} [get]
|
||||||
func (s *Server) serveUploadedStaticAsset(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) serveUploadedStaticAsset(w http.ResponseWriter, r *http.Request) {
|
||||||
s.serveLocalStaticAsset(w, r, s.cfg.LocalUploadedStorageDir, config.DefaultLocalUploadedStorageDir)
|
s.serveLocalStaticAsset(w, r, s.cfg.LocalUploadedStorageDir, config.DefaultLocalUploadedStorageDir)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,13 @@
|
|||||||
package httpapi
|
package httpapi
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
)
|
||||||
|
|
||||||
func prepareCompatibleStream(w http.ResponseWriter) http.Flusher {
|
func prepareCompatibleStream(w http.ResponseWriter) http.Flusher {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
@ -10,55 +17,180 @@ func prepareCompatibleStream(w http.ResponseWriter) http.Flusher {
|
|||||||
return flusher
|
return flusher
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeCompatibleDelta(w http.ResponseWriter, kind string, model string, content string) {
|
type compatibleStreamWriter struct {
|
||||||
if kind == "responses" {
|
kind string
|
||||||
sendSSE(w, "response.output_text.delta", map[string]any{"type": "response.output_text.delta", "delta": content})
|
model string
|
||||||
return
|
includeUsage bool
|
||||||
}
|
|
||||||
sendSSE(w, "message", map[string]any{
|
id string
|
||||||
"id": "chatcmpl-stream",
|
created int64
|
||||||
"object": "chat.completion.chunk",
|
systemFingerprint any
|
||||||
"model": model,
|
sentRole bool
|
||||||
"choices": []any{map[string]any{"index": 0, "delta": map[string]any{"content": content}, "finish_reason": nil}},
|
sentFinish bool
|
||||||
})
|
sentUsage bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeCompatibleDone(w http.ResponseWriter, kind string, model string, output map[string]any) {
|
func newCompatibleStreamWriter(kind string, model string, includeUsage bool) *compatibleStreamWriter {
|
||||||
if kind == "responses" {
|
return &compatibleStreamWriter{
|
||||||
|
kind: kind,
|
||||||
|
model: model,
|
||||||
|
includeUsage: includeUsage,
|
||||||
|
id: "chatcmpl-stream",
|
||||||
|
created: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) writeDelta(w http.ResponseWriter, event clients.StreamDeltaEvent) {
|
||||||
|
if s.kind == "responses" {
|
||||||
|
if event.Text != "" {
|
||||||
|
sendSSE(w, "response.output_text.delta", map[string]any{"type": "response.output_text.delta", "delta": event.Text})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Event != nil && isChatCompletionChunk(event.Event) {
|
||||||
|
s.writeChatChunk(w, event.Event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Text == "" && event.ReasoningContent == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ensureRoleChunk(w)
|
||||||
|
if event.ReasoningContent != "" {
|
||||||
|
s.writeChatData(w, s.chatChunk([]any{map[string]any{"index": 0, "delta": map[string]any{"reasoning_content": event.ReasoningContent}, "finish_reason": nil}}, nil))
|
||||||
|
}
|
||||||
|
if event.Text != "" {
|
||||||
|
s.writeChatData(w, s.chatChunk([]any{map[string]any{"index": 0, "delta": map[string]any{"content": event.Text}, "finish_reason": nil}}, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) writeDone(w http.ResponseWriter, output map[string]any) {
|
||||||
|
if s.kind == "responses" {
|
||||||
sendSSE(w, "response.completed", map[string]any{"type": "response.completed", "response": output})
|
sendSSE(w, "response.completed", map[string]any{"type": "response.completed", "response": output})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sendSSE(w, "message", map[string]any{
|
s.captureOutputMetadata(output)
|
||||||
"id": firstString(output["id"], "chatcmpl-stream"),
|
if !s.sentRole {
|
||||||
|
s.ensureRoleChunk(w)
|
||||||
|
}
|
||||||
|
if !s.sentFinish {
|
||||||
|
s.writeChatData(w, s.chatChunk([]any{map[string]any{"index": 0, "delta": map[string]any{}, "finish_reason": finishReasonFromOutput(output)}}, nil))
|
||||||
|
s.sentFinish = true
|
||||||
|
}
|
||||||
|
if s.includeUsage && !s.sentUsage {
|
||||||
|
if usage, ok := output["usage"].(map[string]any); ok && len(usage) > 0 {
|
||||||
|
s.writeChatData(w, s.chatChunk([]any{}, usage))
|
||||||
|
s.sentUsage = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.writeDoneMarker(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) writeChatChunk(w http.ResponseWriter, chunk map[string]any) {
|
||||||
|
chunk = clients.NormalizeChatCompletionStreamEvent(chunk)
|
||||||
|
s.captureChunkMetadata(chunk)
|
||||||
|
choices, _ := chunk["choices"].([]any)
|
||||||
|
usage, hasUsage := chunk["usage"].(map[string]any)
|
||||||
|
if len(choices) == 0 && hasUsage {
|
||||||
|
if !s.includeUsage {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.writeChatData(w, s.chatChunk([]any{}, usage))
|
||||||
|
s.sentUsage = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(choices) > 0 && !chunkHasRole(choices) && !s.sentRole {
|
||||||
|
s.ensureRoleChunk(w)
|
||||||
|
}
|
||||||
|
if chunkHasRole(choices) {
|
||||||
|
s.sentRole = true
|
||||||
|
}
|
||||||
|
if chunkHasFinishReason(choices) {
|
||||||
|
s.sentFinish = true
|
||||||
|
}
|
||||||
|
normalized := cloneMap(chunk)
|
||||||
|
normalized["id"] = s.id
|
||||||
|
normalized["object"] = "chat.completion.chunk"
|
||||||
|
normalized["created"] = s.created
|
||||||
|
normalized["model"] = firstString(normalized["model"], s.model)
|
||||||
|
normalized["system_fingerprint"] = s.systemFingerprint
|
||||||
|
s.writeChatData(w, normalized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) ensureRoleChunk(w http.ResponseWriter) {
|
||||||
|
if s.sentRole {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.writeChatData(w, s.chatChunk([]any{map[string]any{"index": 0, "delta": map[string]any{"role": "assistant"}, "finish_reason": nil}}, nil))
|
||||||
|
s.sentRole = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) chatChunk(choices []any, usage map[string]any) map[string]any {
|
||||||
|
chunk := map[string]any{
|
||||||
|
"id": s.id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"model": model,
|
"created": s.created,
|
||||||
"choices": []any{map[string]any{"index": 0, "delta": map[string]any{}, "finish_reason": "stop"}},
|
"model": s.model,
|
||||||
})
|
"system_fingerprint": s.systemFingerprint,
|
||||||
|
"choices": choices,
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
chunk["usage"] = usage
|
||||||
|
} else {
|
||||||
|
chunk["usage"] = nil
|
||||||
|
}
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) writeChatData(w http.ResponseWriter, payload map[string]any) {
|
||||||
|
bytes, _ := json.Marshal(payload)
|
||||||
|
_, _ = fmt.Fprintf(w, "data: %s\n\n", bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) writeDoneMarker(w http.ResponseWriter) {
|
||||||
|
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) captureChunkMetadata(chunk map[string]any) {
|
||||||
|
if id := firstString(chunk["id"], ""); id != "" {
|
||||||
|
s.id = id
|
||||||
|
}
|
||||||
|
if model := firstString(chunk["model"], ""); model != "" {
|
||||||
|
s.model = model
|
||||||
|
}
|
||||||
|
if created := int64FromAny(chunk["created"]); created > 0 {
|
||||||
|
s.created = created
|
||||||
|
}
|
||||||
|
if value, ok := chunk["system_fingerprint"]; ok {
|
||||||
|
s.systemFingerprint = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *compatibleStreamWriter) captureOutputMetadata(output map[string]any) {
|
||||||
|
if id := firstString(output["id"], ""); id != "" {
|
||||||
|
s.id = id
|
||||||
|
}
|
||||||
|
if model := firstString(output["model"], ""); model != "" {
|
||||||
|
s.model = model
|
||||||
|
}
|
||||||
|
if created := int64FromAny(output["created"]); created > 0 {
|
||||||
|
s.created = created
|
||||||
|
}
|
||||||
|
if value, ok := output["system_fingerprint"]; ok {
|
||||||
|
s.systemFingerprint = value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeCompatibleStream(w http.ResponseWriter, kind string, model string, output map[string]any) {
|
func writeCompatibleStream(w http.ResponseWriter, kind string, model string, output map[string]any) {
|
||||||
prepareCompatibleStream(w)
|
prepareCompatibleStream(w)
|
||||||
|
writer := newCompatibleStreamWriter(kind, model, true)
|
||||||
content := extractOutputText(output)
|
content := extractOutputText(output)
|
||||||
if content == "" {
|
if content == "" {
|
||||||
content = "done"
|
content = "done"
|
||||||
}
|
}
|
||||||
if kind == "responses" {
|
writer.writeDelta(w, clients.StreamDeltaEvent{Text: content})
|
||||||
sendSSE(w, "response.output_text.delta", map[string]any{"type": "response.output_text.delta", "delta": content})
|
writer.writeDone(w, output)
|
||||||
sendSSE(w, "response.completed", map[string]any{"type": "response.completed", "response": output})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sendSSE(w, "message", map[string]any{
|
|
||||||
"id": output["id"],
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"model": model,
|
|
||||||
"choices": []any{map[string]any{"index": 0, "delta": map[string]any{"content": content}, "finish_reason": nil}},
|
|
||||||
})
|
|
||||||
sendSSE(w, "message", map[string]any{
|
|
||||||
"id": output["id"],
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"model": model,
|
|
||||||
"choices": []any{map[string]any{"index": 0, "delta": map[string]any{}, "finish_reason": "stop"}},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func firstString(value any, fallback string) string {
|
func firstString(value any, fallback string) string {
|
||||||
@ -68,6 +200,68 @@ func firstString(value any, fallback string) string {
|
|||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func int64FromAny(value any) int64 {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case int64:
|
||||||
|
return typed
|
||||||
|
case int:
|
||||||
|
return int64(typed)
|
||||||
|
case float64:
|
||||||
|
return int64(typed)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isChatCompletionChunk(event map[string]any) bool {
|
||||||
|
object, _ := event["object"].(string)
|
||||||
|
if object == "chat.completion.chunk" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, hasChoices := event["choices"].([]any)
|
||||||
|
return hasChoices
|
||||||
|
}
|
||||||
|
|
||||||
|
func chunkHasRole(choices []any) bool {
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
if role, ok := delta["role"].(string); ok && role != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func chunkHasFinishReason(choices []any) bool {
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
if reason, ok := choice["finish_reason"].(string); ok && reason != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func finishReasonFromOutput(output map[string]any) string {
|
||||||
|
choices, _ := output["choices"].([]any)
|
||||||
|
for _, rawChoice := range choices {
|
||||||
|
choice, _ := rawChoice.(map[string]any)
|
||||||
|
if reason, ok := choice["finish_reason"].(string); ok && reason != "" {
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneMap(value map[string]any) map[string]any {
|
||||||
|
out := map[string]any{}
|
||||||
|
for key, item := range value {
|
||||||
|
out[key] = item
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func extractOutputText(output map[string]any) string {
|
func extractOutputText(output map[string]any) string {
|
||||||
if text, ok := output["output_text"].(string); ok {
|
if text, ok := output["output_text"].(string); ok {
|
||||||
return text
|
return text
|
||||||
|
|||||||
@ -8,6 +8,17 @@ import (
|
|||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// listFileStorageChannels godoc
|
||||||
|
// @Summary 列出文件存储通道
|
||||||
|
// @Description 返回所有未删除的文件存储通道,用于管理上传与生成资源回传策略。
|
||||||
|
// @Tags system
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Success 200 {object} FileStorageChannelListResponse
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/system/file-storage/channels [get]
|
||||||
func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request) {
|
||||||
items, err := s.store.ListFileStorageChannels(r.Context())
|
items, err := s.store.ListFileStorageChannels(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -18,6 +29,17 @@ func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request)
|
|||||||
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
writeJSON(w, http.StatusOK, map[string]any{"items": items})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getFileStorageSettings godoc
|
||||||
|
// @Summary 获取文件存储设置
|
||||||
|
// @Description 返回文件存储系统设置;数据库对象尚未创建时返回默认设置。
|
||||||
|
// @Tags system
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Success 200 {object} store.FileStorageSettings
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/system/file-storage/settings [get]
|
||||||
func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
settings, err := s.store.GetFileStorageSettings(r.Context())
|
settings, err := s.store.GetFileStorageSettings(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -32,6 +54,20 @@ func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request)
|
|||||||
writeJSON(w, http.StatusOK, settings)
|
writeJSON(w, http.StatusOK, settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateFileStorageSettings godoc
|
||||||
|
// @Summary 更新文件存储设置
|
||||||
|
// @Description 更新生成资源上传策略等文件存储系统设置。
|
||||||
|
// @Tags system
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param body body store.FileStorageSettingsInput true "文件存储设置"
|
||||||
|
// @Success 200 {object} store.FileStorageSettings
|
||||||
|
// @Failure 400 {object} ErrorEnvelope
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/system/file-storage/settings [patch]
|
||||||
func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
var input store.FileStorageSettingsInput
|
var input store.FileStorageSettingsInput
|
||||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
@ -47,6 +83,21 @@ func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Reques
|
|||||||
writeJSON(w, http.StatusOK, settings)
|
writeJSON(w, http.StatusOK, settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createFileStorageChannel godoc
|
||||||
|
// @Summary 创建文件存储通道
|
||||||
|
// @Description 创建文件存储通道,当前主要用于配置 server-main OpenAPI 上传通道。
|
||||||
|
// @Tags system
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param body body store.FileStorageChannelInput true "文件存储通道"
|
||||||
|
// @Success 201 {object} store.FileStorageChannel
|
||||||
|
// @Failure 400 {object} ErrorEnvelope
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 409 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/system/file-storage/channels [post]
|
||||||
func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request) {
|
||||||
var input store.FileStorageChannelInput
|
var input store.FileStorageChannelInput
|
||||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
@ -70,6 +121,23 @@ func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request
|
|||||||
writeJSON(w, http.StatusCreated, item)
|
writeJSON(w, http.StatusCreated, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateFileStorageChannel godoc
|
||||||
|
// @Summary 更新文件存储通道
|
||||||
|
// @Description 更新指定文件存储通道的名称、凭证、场景、优先级、状态和重试策略。
|
||||||
|
// @Tags system
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param channelID path string true "文件存储通道 ID"
|
||||||
|
// @Param body body store.FileStorageChannelInput true "文件存储通道"
|
||||||
|
// @Success 200 {object} store.FileStorageChannel
|
||||||
|
// @Failure 400 {object} ErrorEnvelope
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 404 {object} ErrorEnvelope
|
||||||
|
// @Failure 409 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/system/file-storage/channels/{channelID} [patch]
|
||||||
func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request) {
|
||||||
var input store.FileStorageChannelInput
|
var input store.FileStorageChannelInput
|
||||||
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
|
||||||
@ -107,6 +175,19 @@ func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request
|
|||||||
writeJSON(w, http.StatusOK, item)
|
writeJSON(w, http.StatusOK, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deleteFileStorageChannel godoc
|
||||||
|
// @Summary 删除文件存储通道
|
||||||
|
// @Description 软删除指定文件存储通道。
|
||||||
|
// @Tags system
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param channelID path string true "文件存储通道 ID"
|
||||||
|
// @Success 204 "No Content"
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 404 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/admin/system/file-storage/channels/{channelID} [delete]
|
||||||
func (s *Server) deleteFileStorageChannel(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) deleteFileStorageChannel(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := s.store.DeleteFileStorageChannel(r.Context(), r.PathValue("channelID")); err != nil {
|
if err := s.store.DeleteFileStorageChannel(r.Context(), r.PathValue("channelID")); err != nil {
|
||||||
if store.IsNotFound(err) {
|
if store.IsNotFound(err) {
|
||||||
|
|||||||
355
apps/api/internal/runner/candidate_filter.go
Normal file
355
apps/api/internal/runner/candidate_filter.go
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
const unsupportedRequestResolutionCode = "unsupported_request_resolution"
|
||||||
|
|
||||||
|
type requestResolutionRequirement struct {
|
||||||
|
Kind string
|
||||||
|
RequestedModel string
|
||||||
|
ModelType string
|
||||||
|
Resolution string
|
||||||
|
Source string
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type videoResolutionReferenceStats struct {
|
||||||
|
HasFirstFrame bool
|
||||||
|
HasLastFrame bool
|
||||||
|
ReferenceImages int
|
||||||
|
HasReferenceVideo bool
|
||||||
|
HasReferenceAudio bool
|
||||||
|
HasAnyMedia bool
|
||||||
|
HasExplicitContent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterRuntimeCandidatesByRequest(kind string, requestedModel string, modelType string, body map[string]any, candidates []store.RuntimeModelCandidate) ([]store.RuntimeModelCandidate, map[string]any, error) {
|
||||||
|
requirement, ok := requestResolutionRequirementFor(kind, requestedModel, modelType, body)
|
||||||
|
if !ok || len(candidates) == 0 {
|
||||||
|
return candidates, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]store.RuntimeModelCandidate, 0, len(candidates))
|
||||||
|
rejected := make([]map[string]any, 0)
|
||||||
|
supportedResolutions := make([]string, 0)
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
supported, detail := candidateSupportsRequestResolution(candidate, requirement)
|
||||||
|
if supported {
|
||||||
|
filtered = append(filtered, candidate)
|
||||||
|
for _, value := range stringListFromAny(detail["allowedResolutions"]) {
|
||||||
|
appendUniqueString(&supportedResolutions, value)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rejected = append(rejected, detail)
|
||||||
|
for _, value := range stringListFromAny(detail["allowedResolutions"]) {
|
||||||
|
appendUniqueString(&supportedResolutions, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := requestResolutionFilterSummary(requirement, len(candidates), len(filtered), rejected, supportedResolutions)
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return nil, summary, &store.ModelCandidateUnavailableError{
|
||||||
|
Code: unsupportedRequestResolutionCode,
|
||||||
|
Message: unsupportedRequestResolutionMessage(requirement, rejected),
|
||||||
|
Details: summary,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered, summary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestResolutionRequirementFor(kind string, requestedModel string, modelType string, body map[string]any) (requestResolutionRequirement, bool) {
|
||||||
|
if !isResolutionFilteredModelType(modelType) {
|
||||||
|
return requestResolutionRequirement{}, false
|
||||||
|
}
|
||||||
|
resolution, source := requestResolutionValue(body, modelType)
|
||||||
|
if resolution == "" {
|
||||||
|
return requestResolutionRequirement{}, false
|
||||||
|
}
|
||||||
|
return requestResolutionRequirement{
|
||||||
|
Kind: kind,
|
||||||
|
RequestedModel: requestedModel,
|
||||||
|
ModelType: modelType,
|
||||||
|
Resolution: resolution,
|
||||||
|
Source: source,
|
||||||
|
Scopes: requestResolutionScopes(body, modelType),
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestResolutionValue(body map[string]any, modelType string) (string, string) {
|
||||||
|
if value := normalizedRequestResolution(stringFromAny(body["resolution"])); value != "" {
|
||||||
|
return value, "resolution"
|
||||||
|
}
|
||||||
|
size := normalizedRequestResolution(stringFromAny(body["size"]))
|
||||||
|
if size == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if isImageResolution(modelType, size) || isVideoResolution(modelType, size) {
|
||||||
|
return size, "size"
|
||||||
|
}
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizedRequestResolution(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" || isEmptyParamString(value) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch strings.ToLower(value) {
|
||||||
|
case "auto", "automatic", "adaptive", "default":
|
||||||
|
return ""
|
||||||
|
default:
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isResolutionFilteredModelType(modelType string) bool {
|
||||||
|
return modelType == "image_generate" || modelType == "image_edit" || isVideoModelType(modelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func candidateSupportsRequestResolution(candidate store.RuntimeModelCandidate, requirement requestResolutionRequirement) (bool, map[string]any) {
|
||||||
|
modelType := firstNonEmptyString(candidate.ModelType, requirement.ModelType)
|
||||||
|
capability := capabilityForType(effectiveModelCapability(candidate), modelType)
|
||||||
|
detail := candidateResolutionDetail(candidate, requirement, modelType)
|
||||||
|
if capability == nil {
|
||||||
|
detail["reason"] = "capability_missing"
|
||||||
|
detail["message"] = "候选平台模型未配置对应模型类型能力。"
|
||||||
|
detail["capabilityPath"] = capabilityPath(modelType, "output_resolutions")
|
||||||
|
return false, detail
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, configured := outputResolutionAllowedValues(capability["output_resolutions"], requirement.Scopes)
|
||||||
|
detail["allowedResolutions"] = allowed
|
||||||
|
detail["capabilityPath"] = capabilityPath(modelType, "output_resolutions")
|
||||||
|
detail["capabilityValue"] = cloneAny(capability["output_resolutions"])
|
||||||
|
if !configured {
|
||||||
|
detail["reason"] = "output_resolutions_missing"
|
||||||
|
detail["message"] = "候选平台模型未声明 output_resolutions。"
|
||||||
|
return false, detail
|
||||||
|
}
|
||||||
|
if containsResolution(allowed, requirement.Resolution) {
|
||||||
|
detail["reason"] = "supported"
|
||||||
|
return true, detail
|
||||||
|
}
|
||||||
|
detail["reason"] = "resolution_not_allowed"
|
||||||
|
detail["message"] = "请求分辨率不在候选平台模型 output_resolutions 中。"
|
||||||
|
return false, detail
|
||||||
|
}
|
||||||
|
|
||||||
|
func outputResolutionAllowedValues(value any, scopes []string) ([]string, bool) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case []any, []string, string:
|
||||||
|
return uniqueStringList(stringListFromAny(typed)), true
|
||||||
|
case map[string]any:
|
||||||
|
for _, scope := range append(scopes, "default", "*", "all") {
|
||||||
|
if scope == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if raw, ok := typed[scope]; ok {
|
||||||
|
return uniqueStringList(stringListFromAny(raw)), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(scopes) == 0 {
|
||||||
|
values := make([]string, 0)
|
||||||
|
for _, raw := range typed {
|
||||||
|
values = append(values, stringListFromAny(raw)...)
|
||||||
|
}
|
||||||
|
return uniqueStringList(values), len(values) > 0
|
||||||
|
}
|
||||||
|
return nil, true
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsResolution(values []string, target string) bool {
|
||||||
|
for _, value := range values {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(value), strings.TrimSpace(target)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func candidateResolutionDetail(candidate store.RuntimeModelCandidate, requirement requestResolutionRequirement, modelType string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"platformId": candidate.PlatformID,
|
||||||
|
"platformKey": candidate.PlatformKey,
|
||||||
|
"platformName": candidate.PlatformName,
|
||||||
|
"provider": candidate.Provider,
|
||||||
|
"platformModelId": candidate.PlatformModelID,
|
||||||
|
"modelName": candidate.ModelName,
|
||||||
|
"modelAlias": candidate.ModelAlias,
|
||||||
|
"displayName": candidate.DisplayName,
|
||||||
|
"providerModelName": candidate.ProviderModelName,
|
||||||
|
"modelType": modelType,
|
||||||
|
"requested": map[string]any{
|
||||||
|
"resolution": requirement.Resolution,
|
||||||
|
"source": requirement.Source,
|
||||||
|
"scopes": requirement.Scopes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestResolutionFilterSummary(requirement requestResolutionRequirement, candidateCount int, supportedCandidateCount int, rejected []map[string]any, supportedResolutions []string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"code": unsupportedRequestResolutionCode,
|
||||||
|
"filter": "request_resolution",
|
||||||
|
"kind": requirement.Kind,
|
||||||
|
"requestedModel": requirement.RequestedModel,
|
||||||
|
"modelType": requirement.ModelType,
|
||||||
|
"requestedResolution": requirement.Resolution,
|
||||||
|
"resolutionSource": requirement.Source,
|
||||||
|
"resolutionScopes": requirement.Scopes,
|
||||||
|
"capabilityPath": capabilityPath(requirement.ModelType, "output_resolutions"),
|
||||||
|
"candidateCount": candidateCount,
|
||||||
|
"supportedCandidateCount": supportedCandidateCount,
|
||||||
|
"filteredCandidateCount": len(rejected),
|
||||||
|
"supportedResolutions": uniqueStringList(supportedResolutions),
|
||||||
|
"rejectedCandidates": rejected,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unsupportedRequestResolutionMessage(requirement requestResolutionRequirement, rejected []map[string]any) string {
|
||||||
|
resource := "媒体"
|
||||||
|
if requirement.ModelType == "image_generate" || requirement.ModelType == "image_edit" {
|
||||||
|
resource = "图像"
|
||||||
|
} else if isVideoModelType(requirement.ModelType) {
|
||||||
|
resource = "视频"
|
||||||
|
}
|
||||||
|
message := fmt.Sprintf("请求的%s分辨率 %s 没有可用平台模型支持,已过滤 %d 个候选平台模型", resource, requirement.Resolution, len(rejected))
|
||||||
|
if summaries := rejectedResolutionSummaries(rejected, 3); len(summaries) > 0 {
|
||||||
|
message += ";候选支持:" + strings.Join(summaries, ";")
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
func rejectedResolutionSummaries(rejected []map[string]any, limit int) []string {
|
||||||
|
summaries := make([]string, 0, limit)
|
||||||
|
for _, item := range rejected {
|
||||||
|
if len(summaries) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
allowed := stringListFromAny(item["allowedResolutions"])
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := firstNonEmptyString(stringFromAny(item["platformName"]), stringFromAny(item["platformKey"]), stringFromAny(item["provider"]))
|
||||||
|
model := firstNonEmptyString(stringFromAny(item["displayName"]), stringFromAny(item["modelAlias"]), stringFromAny(item["modelName"]))
|
||||||
|
if model != "" {
|
||||||
|
name = firstNonEmptyString(name, model)
|
||||||
|
if name != model {
|
||||||
|
name += "/" + model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = "候选"
|
||||||
|
}
|
||||||
|
summaries = append(summaries, fmt.Sprintf("%s=%s", name, strings.Join(allowed, "/")))
|
||||||
|
}
|
||||||
|
return summaries
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestResolutionScopes(body map[string]any, modelType string) []string {
|
||||||
|
if !isVideoModelType(modelType) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
scopes := make([]string, 0)
|
||||||
|
for _, key := range []string{"videoMode", "video_mode", "mode", "generation_mode", "generate_mode", "supported_mode"} {
|
||||||
|
appendUniqueString(&scopes, stringFromMap(body, key))
|
||||||
|
}
|
||||||
|
stats := videoResolutionReferenceStatsFromBody(body)
|
||||||
|
if stats.HasFirstFrame && stats.HasLastFrame {
|
||||||
|
appendUniqueString(&scopes, "input_first_last_frame")
|
||||||
|
appendUniqueString(&scopes, "first_last_frame")
|
||||||
|
} else if stats.HasFirstFrame {
|
||||||
|
appendUniqueString(&scopes, "input_first_frame")
|
||||||
|
} else if stats.HasLastFrame {
|
||||||
|
appendUniqueString(&scopes, "input_last_frame")
|
||||||
|
}
|
||||||
|
if stats.ReferenceImages > 1 {
|
||||||
|
appendUniqueString(&scopes, "input_reference_generate_multiple")
|
||||||
|
appendUniqueString(&scopes, "image_reference")
|
||||||
|
} else if stats.ReferenceImages == 1 {
|
||||||
|
appendUniqueString(&scopes, "input_reference_generate_single")
|
||||||
|
appendUniqueString(&scopes, "image_reference")
|
||||||
|
}
|
||||||
|
if stats.HasReferenceVideo {
|
||||||
|
appendUniqueString(&scopes, "video_reference")
|
||||||
|
}
|
||||||
|
if stats.HasReferenceAudio {
|
||||||
|
appendUniqueString(&scopes, "audio_reference")
|
||||||
|
}
|
||||||
|
if !stats.HasAnyMedia {
|
||||||
|
appendUniqueString(&scopes, "text_to_video")
|
||||||
|
}
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
func videoResolutionReferenceStatsFromBody(body map[string]any) videoResolutionReferenceStats {
|
||||||
|
stats := videoResolutionReferenceStats{}
|
||||||
|
content := contentItems(body["content"])
|
||||||
|
stats.HasExplicitContent = len(content) > 0
|
||||||
|
for _, item := range content {
|
||||||
|
if isImageContent(item) {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
switch strings.TrimSpace(stringFromAny(item["role"])) {
|
||||||
|
case "first_frame":
|
||||||
|
stats.HasFirstFrame = true
|
||||||
|
case "last_frame":
|
||||||
|
stats.HasLastFrame = true
|
||||||
|
default:
|
||||||
|
stats.ReferenceImages++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isVideoContent(item) {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.HasReferenceVideo = true
|
||||||
|
}
|
||||||
|
if isAudioContent(item) {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.HasReferenceAudio = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasAnyString(body, "first_frame", "firstFrame") {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.HasFirstFrame = true
|
||||||
|
}
|
||||||
|
if hasAnyString(body, "last_frame", "lastFrame") {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.HasLastFrame = true
|
||||||
|
}
|
||||||
|
if hasAnyString(body, "reference_image", "referenceImage") {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.ReferenceImages++
|
||||||
|
}
|
||||||
|
if hasAnyString(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.HasReferenceVideo = true
|
||||||
|
}
|
||||||
|
if hasAnyString(body, "audio_url", "audioUrl", "reference_audio", "referenceAudio") {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
stats.HasReferenceAudio = true
|
||||||
|
}
|
||||||
|
if hasAnyString(body, "image", "images", "image_url", "imageUrl", "image_urls", "imageUrls") {
|
||||||
|
stats.HasAnyMedia = true
|
||||||
|
if !stats.HasFirstFrame && !stats.HasExplicitContent {
|
||||||
|
stats.HasFirstFrame = true
|
||||||
|
} else {
|
||||||
|
stats.ReferenceImages++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
func candidateCapabilityFilterMetrics(summary map[string]any) map[string]any {
|
||||||
|
if len(summary) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return map[string]any{"candidateCapabilityFilter": summary}
|
||||||
|
}
|
||||||
191
apps/api/internal/runner/candidate_filter_test.go
Normal file
191
apps/api/internal/runner/candidate_filter_test.go
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterRuntimeCandidatesByRequestResolutionKeepsSupportedCandidate(t *testing.T) {
|
||||||
|
candidates := []store.RuntimeModelCandidate{
|
||||||
|
candidateWithResolutions("low", "720p"),
|
||||||
|
candidateWithResolutions("high", "1080p"),
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered, summary, err := filterRuntimeCandidatesByRequest("videos.generations", "demo-video", "video_generate", map[string]any{
|
||||||
|
"resolution": "1080p",
|
||||||
|
}, candidates)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("filter should keep a supported candidate: %v", err)
|
||||||
|
}
|
||||||
|
if len(filtered) != 1 || filtered[0].PlatformKey != "high" {
|
||||||
|
t.Fatalf("expected only high resolution candidate, got %+v", filtered)
|
||||||
|
}
|
||||||
|
if summary["filteredCandidateCount"] != 1 || summary["supportedCandidateCount"] != 1 {
|
||||||
|
t.Fatalf("unexpected filter summary: %+v", summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterRuntimeCandidatesByScopedVideoResolution(t *testing.T) {
|
||||||
|
candidates := []store.RuntimeModelCandidate{
|
||||||
|
{
|
||||||
|
PlatformID: "platform-first",
|
||||||
|
PlatformKey: "first",
|
||||||
|
PlatformName: "First Frame Platform",
|
||||||
|
PlatformModelID: "model-first",
|
||||||
|
ModelName: "demo-video",
|
||||||
|
ModelType: "image_to_video",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"image_to_video": map[string]any{
|
||||||
|
"output_resolutions": map[string]any{
|
||||||
|
"input_first_frame": []any{"1080p"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PlatformID: "platform-first-last",
|
||||||
|
PlatformKey: "first-last",
|
||||||
|
PlatformName: "First Last Platform",
|
||||||
|
PlatformModelID: "model-first-last",
|
||||||
|
ModelName: "demo-video",
|
||||||
|
ModelType: "image_to_video",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"image_to_video": map[string]any{
|
||||||
|
"output_resolutions": map[string]any{
|
||||||
|
"input_first_last_frame": []any{"1080p"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered, _, err := filterRuntimeCandidatesByRequest("videos.generations", "demo-video", "image_to_video", map[string]any{
|
||||||
|
"resolution": "1080p",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}},
|
||||||
|
map[string]any{"type": "image_url", "role": "last_frame", "image_url": map[string]any{"url": "https://example.com/last.png"}},
|
||||||
|
},
|
||||||
|
}, candidates)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("filter should keep first-last scoped candidate: %v", err)
|
||||||
|
}
|
||||||
|
if len(filtered) != 1 || filtered[0].PlatformKey != "first-last" {
|
||||||
|
t.Fatalf("expected first-last scoped candidate only, got %+v", filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterRuntimeCandidatesByRequestResolutionFailsWithDetails(t *testing.T) {
|
||||||
|
candidates := []store.RuntimeModelCandidate{
|
||||||
|
candidateWithImageResolutions("jimeng-v3", "1K", "2K"),
|
||||||
|
candidateWithImageResolutions("jimeng-v4", "1K"),
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered, summary, err := filterRuntimeCandidatesByRequest("images.generations", "demo-image", "image_generate", map[string]any{
|
||||||
|
"resolution": "4K",
|
||||||
|
}, candidates)
|
||||||
|
if len(filtered) != 0 {
|
||||||
|
t.Fatalf("expected no candidates, got %+v", filtered)
|
||||||
|
}
|
||||||
|
var candidateErr *store.ModelCandidateUnavailableError
|
||||||
|
if !errors.As(err, &candidateErr) {
|
||||||
|
t.Fatalf("expected model candidate error, got %T %v", err, err)
|
||||||
|
}
|
||||||
|
if candidateErr.Code != unsupportedRequestResolutionCode {
|
||||||
|
t.Fatalf("unexpected error code: %s", candidateErr.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(candidateErr.Message, "4K") {
|
||||||
|
t.Fatalf("message should include requested resolution, got %q", candidateErr.Message)
|
||||||
|
}
|
||||||
|
if summary["filteredCandidateCount"] != 2 || candidateErr.Details["requestedResolution"] != "4K" {
|
||||||
|
t.Fatalf("unexpected filter detail summary=%+v details=%+v", summary, candidateErr.Details)
|
||||||
|
}
|
||||||
|
if details := store.ModelCandidateErrorDetails(err); details["requestedResolution"] != "4K" {
|
||||||
|
t.Fatalf("store detail helper should expose requested resolution, got %+v", details)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterRuntimeCandidatesSkipsPixelSizeCompatibility(t *testing.T) {
|
||||||
|
candidates := []store.RuntimeModelCandidate{{
|
||||||
|
PlatformID: "openai",
|
||||||
|
PlatformKey: "openai",
|
||||||
|
PlatformModelID: "gpt-image-1",
|
||||||
|
ModelName: "gpt-image-1",
|
||||||
|
ModelType: "image_generate",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"image_generate": map[string]any{
|
||||||
|
"aspect_ratio_allowed": []any{"1:1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
filtered, summary, err := filterRuntimeCandidatesByRequest("images.generations", "gpt-image-1", "image_generate", map[string]any{
|
||||||
|
"size": "1024x1024",
|
||||||
|
}, candidates)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("pixel size compatibility should skip resolution filtering: %v", err)
|
||||||
|
}
|
||||||
|
if len(filtered) != 1 || summary != nil {
|
||||||
|
t.Fatalf("expected unchanged candidates and no summary, got filtered=%+v summary=%+v", filtered, summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFailureResultIncludesModelCandidateDetails(t *testing.T) {
|
||||||
|
cause := &store.ModelCandidateUnavailableError{
|
||||||
|
Code: unsupportedRequestResolutionCode,
|
||||||
|
Message: "unsupported resolution",
|
||||||
|
Details: map[string]any{
|
||||||
|
"requestedResolution": "4K",
|
||||||
|
"candidateCount": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := buildFailureResult(store.ModelCandidateErrorCode(cause), cause.Error(), "", cause)
|
||||||
|
errorPayload, _ := result["error"].(map[string]any)
|
||||||
|
modelCandidate, _ := errorPayload["modelCandidate"].(map[string]any)
|
||||||
|
if errorPayload["code"] != unsupportedRequestResolutionCode || modelCandidate["requestedResolution"] != "4K" {
|
||||||
|
t.Fatalf("failure result should persist candidate details, got %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func candidateWithResolutions(platformKey string, resolutions ...string) store.RuntimeModelCandidate {
|
||||||
|
return store.RuntimeModelCandidate{
|
||||||
|
PlatformID: "platform-" + platformKey,
|
||||||
|
PlatformKey: platformKey,
|
||||||
|
PlatformName: "Platform " + platformKey,
|
||||||
|
PlatformModelID: "model-" + platformKey,
|
||||||
|
ModelName: "demo-video",
|
||||||
|
ModelType: "video_generate",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"video_generate": map[string]any{
|
||||||
|
"output_resolutions": stringsToAny(resolutions),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func candidateWithImageResolutions(platformKey string, resolutions ...string) store.RuntimeModelCandidate {
|
||||||
|
return store.RuntimeModelCandidate{
|
||||||
|
PlatformID: "platform-" + platformKey,
|
||||||
|
PlatformKey: platformKey,
|
||||||
|
PlatformName: "Platform " + platformKey,
|
||||||
|
PlatformModelID: "model-" + platformKey,
|
||||||
|
ModelName: "demo-image",
|
||||||
|
ModelType: "image_generate",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"image_generate": map[string]any{
|
||||||
|
"output_resolutions": stringsToAny(resolutions),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsToAny(values []string) []any {
|
||||||
|
out := make([]any, 0, len(values))
|
||||||
|
for _, value := range values {
|
||||||
|
out = append(out, value)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
196
apps/api/internal/runner/param_processor_script.go
Normal file
196
apps/api/internal/runner/param_processor_script.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) preprocessRequestWithScripts(ctx context.Context, kind string, body map[string]any, candidate store.RuntimeModelCandidate) parameterPreprocessResult {
|
||||||
|
if platformConfigBool(candidate.PlatformConfig, "skipParamNormalization", "skip_param_normalization") {
|
||||||
|
modelType := strings.TrimSpace(candidate.ModelType)
|
||||||
|
if modelType == "" {
|
||||||
|
modelType = modelTypeFromKind(kind, body)
|
||||||
|
}
|
||||||
|
input := cloneMap(body)
|
||||||
|
return parameterPreprocessResult{
|
||||||
|
Body: cloneMap(body),
|
||||||
|
Log: parameterPreprocessingLog{
|
||||||
|
ModelType: modelType,
|
||||||
|
Input: input,
|
||||||
|
Output: cloneMap(body),
|
||||||
|
Changed: false,
|
||||||
|
Changes: []parameterPreprocessChange{},
|
||||||
|
Model: preprocessingModelSnapshot(candidate),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := preprocessRequestWithLog(kind, body, candidate)
|
||||||
|
if result.Err != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
scriptText := platformConfigString(candidate.PlatformConfig, "customPreprocessScript", "custom_preprocess_script")
|
||||||
|
if strings.TrimSpace(scriptText) == "" || s.scriptExecutor == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
before := cloneMap(result.Body)
|
||||||
|
scriptContext := s.scriptContext(candidate, result.Log.ModelType, nil, map[string]any{
|
||||||
|
"modelCapability": effectiveModelCapability(candidate),
|
||||||
|
"platformModel": result.Log.Model,
|
||||||
|
"platform": candidate.PlatformConfig,
|
||||||
|
})
|
||||||
|
out, err := s.scriptExecutor.Execute(ctx, scriptengine.Options{
|
||||||
|
Script: scriptText,
|
||||||
|
Args: []any{cloneMap(result.Body), result.Log.ModelType, scriptContext},
|
||||||
|
ContextData: scriptContext,
|
||||||
|
ScriptName: "custom_preprocess_script:" + result.Log.ModelType,
|
||||||
|
PreferredEntryNames: []string{"preprocessParams", "preprocess", "main", "handler"},
|
||||||
|
Timeout: scriptengine.PreprocessTimeout,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
result.Log.recordScriptChange("CustomPreprocessScript", "error", "$", before, result.Body, err.Error())
|
||||||
|
result.Log.Output = cloneMap(result.Body)
|
||||||
|
result.Log.Changed = len(result.Log.Changes) > 0
|
||||||
|
result.Err = err
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
rewritten, ok := out.(map[string]any)
|
||||||
|
if !ok || rewritten == nil {
|
||||||
|
result.Log.Output = cloneMap(result.Body)
|
||||||
|
result.Log.Changed = len(result.Log.Changes) > 0
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
merged := cloneMap(result.Body)
|
||||||
|
for key, value := range rewritten {
|
||||||
|
merged[key] = value
|
||||||
|
}
|
||||||
|
if !mapsEqual(before, merged) {
|
||||||
|
result.Log.recordScriptChange("CustomPreprocessScript", "rewrite", "$", before, merged, "platform custom preprocess script returned parameter updates")
|
||||||
|
}
|
||||||
|
result.Body = merged
|
||||||
|
result.Log.Output = cloneMap(merged)
|
||||||
|
result.Log.Changed = len(result.Log.Changes) > 0
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) scriptContext(candidate store.RuntimeModelCandidate, modelType string, payload map[string]any, extra map[string]any) map[string]any {
|
||||||
|
getTaskURL := platformConfigString(candidate.PlatformConfig, "getTaskURL", "get_task_url")
|
||||||
|
baseURL := strings.TrimRight(strings.TrimSpace(candidate.BaseURL), "/")
|
||||||
|
env := cloneMap(candidate.PlatformConfig)
|
||||||
|
context := map[string]any{
|
||||||
|
"__easyaiScriptContext": true,
|
||||||
|
"baseURL": baseURL,
|
||||||
|
"getTaskURL": getTaskURL,
|
||||||
|
"authValues": cloneMap(candidate.Credentials),
|
||||||
|
"headers": map[string]any{},
|
||||||
|
"payload": cloneMap(payload),
|
||||||
|
"type": modelType,
|
||||||
|
"options": map[string]any{
|
||||||
|
"model": candidate.ModelName,
|
||||||
|
"providerModelName": candidate.ProviderModelName,
|
||||||
|
"platformId": candidate.PlatformID,
|
||||||
|
"platformModelId": candidate.PlatformModelID,
|
||||||
|
"canonicalModelKey": candidate.CanonicalModelKey,
|
||||||
|
"sourceProviderCode": candidate.Provider,
|
||||||
|
},
|
||||||
|
"env": env,
|
||||||
|
"candidate": preprocessingModelSnapshot(candidate),
|
||||||
|
}
|
||||||
|
context["createRequestURL"] = func(path string, base ...string) string {
|
||||||
|
selectedBase := baseURL
|
||||||
|
if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
|
||||||
|
selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return selectedBase + "/" + strings.TrimLeft(path, "/")
|
||||||
|
}
|
||||||
|
context["creatRequestURL"] = context["createRequestURL"]
|
||||||
|
context["resolveGetTaskURL"] = func(taskID string) string {
|
||||||
|
return resolveTaskURLTemplate(getTaskURL, taskID, "")
|
||||||
|
}
|
||||||
|
for key, value := range extra {
|
||||||
|
context[key] = value
|
||||||
|
}
|
||||||
|
return context
|
||||||
|
}
|
||||||
|
|
||||||
|
func preprocessingModelSnapshot(candidate store.RuntimeModelCandidate) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"modelName": candidate.ModelName,
|
||||||
|
"modelAlias": candidate.ModelAlias,
|
||||||
|
"providerModelName": candidate.ProviderModelName,
|
||||||
|
"provider": candidate.Provider,
|
||||||
|
"platformId": candidate.PlatformID,
|
||||||
|
"platformModelId": candidate.PlatformModelID,
|
||||||
|
"capabilities": cloneMap(candidate.Capabilities),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (log *parameterPreprocessingLog) recordScriptChange(processor string, action string, path string, before any, after any, reason string) {
|
||||||
|
if log == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Changes = append(log.Changes, parameterPreprocessChange{
|
||||||
|
Processor: processor,
|
||||||
|
Action: action,
|
||||||
|
Path: path,
|
||||||
|
Before: cloneAny(before),
|
||||||
|
After: cloneAny(after),
|
||||||
|
Reason: reason,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func platformConfigString(config map[string]any, keys ...string) string {
|
||||||
|
for _, key := range keys {
|
||||||
|
if value := strings.TrimSpace(fmt.Sprint(config[key])); value != "" && value != "<nil>" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func platformConfigBool(config map[string]any, keys ...string) bool {
|
||||||
|
for _, key := range keys {
|
||||||
|
switch value := config[key].(type) {
|
||||||
|
case bool:
|
||||||
|
return value
|
||||||
|
case string:
|
||||||
|
return strings.EqualFold(strings.TrimSpace(value), "true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveTaskURLTemplate(template string, upstreamTaskID string, taskID string) string {
|
||||||
|
out := strings.TrimSpace(template)
|
||||||
|
replacements := [][2]string{
|
||||||
|
{"${upstream_task_id}", upstreamTaskID},
|
||||||
|
{"{{upstream_task_id}}", upstreamTaskID},
|
||||||
|
{"{upstream_task_id}", upstreamTaskID},
|
||||||
|
{"${task_id}", taskID},
|
||||||
|
{"{{task_id}}", taskID},
|
||||||
|
{"{task_id}", taskID},
|
||||||
|
{"${taskId}", upstreamTaskID},
|
||||||
|
{"${taskID}", upstreamTaskID},
|
||||||
|
{"{{taskId}}", upstreamTaskID},
|
||||||
|
{"{{taskID}}", upstreamTaskID},
|
||||||
|
{"{taskId}", upstreamTaskID},
|
||||||
|
{"{taskID}", upstreamTaskID},
|
||||||
|
}
|
||||||
|
for _, replacement := range replacements {
|
||||||
|
out = strings.ReplaceAll(out, replacement[0], replacement[1])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapsEqual(left map[string]any, right map[string]any) bool {
|
||||||
|
return reflect.DeepEqual(left, right)
|
||||||
|
}
|
||||||
64
apps/api/internal/runner/param_processor_script_test.go
Normal file
64
apps/api/internal/runner/param_processor_script_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPreprocessRequestWithCustomScript(t *testing.T) {
|
||||||
|
service := &Service{scriptExecutor: &scriptengine.Executor{}}
|
||||||
|
candidate := store.RuntimeModelCandidate{
|
||||||
|
Provider: "universal",
|
||||||
|
ModelName: "image-model",
|
||||||
|
ModelType: "image_generate",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"image_generate": map[string]any{"max_output_images": 4},
|
||||||
|
},
|
||||||
|
PlatformConfig: map[string]any{
|
||||||
|
"customPreprocessScript": `(params, type, context) => {
|
||||||
|
return { prompt: params.prompt + "-" + type, n: 2, provider: context.candidate.provider };
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 8}, candidate)
|
||||||
|
if result.Err != nil {
|
||||||
|
t.Fatalf("unexpected preprocess error: %v", result.Err)
|
||||||
|
}
|
||||||
|
if result.Body["prompt"] != "hello-image_generate" || result.Body["n"].(float64) != 2 {
|
||||||
|
t.Fatalf("unexpected body: %#v", result.Body)
|
||||||
|
}
|
||||||
|
if !result.Log.Changed || len(result.Log.Changes) == 0 {
|
||||||
|
t.Fatalf("expected script change in log: %#v", result.Log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPreprocessRequestSkipParamNormalizationSkipsCustomScript(t *testing.T) {
|
||||||
|
service := &Service{scriptExecutor: &scriptengine.Executor{}}
|
||||||
|
candidate := store.RuntimeModelCandidate{
|
||||||
|
ModelName: "image-model",
|
||||||
|
ModelType: "image_generate",
|
||||||
|
Provider: "universal",
|
||||||
|
Capabilities: map[string]any{
|
||||||
|
"image_generate": map[string]any{"max_output_images": 1},
|
||||||
|
},
|
||||||
|
PlatformConfig: map[string]any{
|
||||||
|
"skipParamNormalization": true,
|
||||||
|
"customPreprocessScript": `(params) => ({ prompt: "changed", n: 1 })`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 9}, candidate)
|
||||||
|
if result.Err != nil {
|
||||||
|
t.Fatalf("unexpected preprocess error: %v", result.Err)
|
||||||
|
}
|
||||||
|
if result.Body["prompt"] != "hello" || result.Body["n"].(int) != 9 {
|
||||||
|
t.Fatalf("skip should keep raw body, got %#v", result.Body)
|
||||||
|
}
|
||||||
|
if result.Log.Changed || len(result.Log.Changes) != 0 {
|
||||||
|
t.Fatalf("skip should not record changes: %#v", result.Log)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -470,7 +470,7 @@ func isEmptyParamString(value string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isImageResolution(modelType string, value string) bool {
|
func isImageResolution(modelType string, value string) bool {
|
||||||
return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "4K", "8K"}, value)
|
return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "3K", "4K", "8K"}, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isVideoResolution(modelType string, value string) bool {
|
func isVideoResolution(modelType string, value string) bool {
|
||||||
|
|||||||
@ -19,7 +19,12 @@ type EstimateResult struct {
|
|||||||
|
|
||||||
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
|
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
|
||||||
body = normalizeRequest(kind, body)
|
body = normalizeRequest(kind, body)
|
||||||
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user)
|
modelType := modelTypeFromKind(kind, body)
|
||||||
|
candidates, err := s.store.ListModelCandidates(ctx, model, modelType, user)
|
||||||
|
if err != nil {
|
||||||
|
return EstimateResult{}, err
|
||||||
|
}
|
||||||
|
candidates, _, err = filterRuntimeCandidatesByRequest(kind, model, modelType, body, candidates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return EstimateResult{}, err
|
return EstimateResult{}, err
|
||||||
}
|
}
|
||||||
@ -77,19 +82,23 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
|
|||||||
resource = "video"
|
resource = "video"
|
||||||
unit = "5s_video"
|
unit = "5s_video"
|
||||||
baseKey = "videoBase"
|
baseKey = "videoBase"
|
||||||
duration := requestDurationSeconds(body)
|
duration, durationSource := billingDurationSeconds(body, response)
|
||||||
|
audioEnabled, audioSource := billingAudioEnabled(body, response)
|
||||||
durationUnits := math.Max(1, math.Ceil(duration/5))
|
durationUnits := math.Max(1, math.Ceil(duration/5))
|
||||||
amount := float64(count) *
|
amount := float64(count) *
|
||||||
durationUnits *
|
durationUnits *
|
||||||
resourcePrice(config, resource, baseKey, "basePrice") *
|
resourcePrice(config, resource, baseKey, "basePrice") *
|
||||||
resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) *
|
resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) *
|
||||||
resourceWeight(config, resource, "audioWeights", boolWeightKey(boolishValue(body["audio"]))) *
|
resourceWeight(config, resource, "audioWeights", boolWeightKey(audioEnabled)) *
|
||||||
resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) *
|
resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) *
|
||||||
resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body))) *
|
resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body, audioEnabled))) *
|
||||||
discount
|
discount
|
||||||
return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{
|
return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{
|
||||||
"count": count,
|
"count": count,
|
||||||
|
"audio": audioEnabled,
|
||||||
|
"audioSource": audioSource,
|
||||||
"durationSeconds": duration,
|
"durationSeconds": duration,
|
||||||
|
"durationSource": durationSource,
|
||||||
"durationUnit": "5s",
|
"durationUnit": "5s",
|
||||||
"durationUnitCount": durationUnits,
|
"durationUnitCount": durationUnits,
|
||||||
})}
|
})}
|
||||||
@ -340,6 +349,54 @@ func requestDurationSeconds(body map[string]any) float64 {
|
|||||||
return 5
|
return 5
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func billingDurationSeconds(body map[string]any, response clients.Response) (float64, string) {
|
||||||
|
if duration, ok := generatedVideoDurationSeconds(response.Result); ok {
|
||||||
|
return duration, "generated_video"
|
||||||
|
}
|
||||||
|
return requestDurationSeconds(body), "preprocessed_request"
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatedVideoDurationSeconds(result map[string]any) (float64, bool) {
|
||||||
|
data, _ := result["data"].([]any)
|
||||||
|
for _, raw := range data {
|
||||||
|
item, _ := raw.(map[string]any)
|
||||||
|
if len(item) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
duration := floatFromAny(item["duration"])
|
||||||
|
if duration <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rounded := math.Round(duration)
|
||||||
|
if rounded <= 0 {
|
||||||
|
rounded = 1
|
||||||
|
}
|
||||||
|
return rounded, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func billingAudioEnabled(body map[string]any, response clients.Response) (bool, string) {
|
||||||
|
if value, ok := generatedVideoHasAudio(response.Result); ok {
|
||||||
|
return value, "generated_video"
|
||||||
|
}
|
||||||
|
return boolishValue(body["audio"]), "preprocessed_request"
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatedVideoHasAudio(result map[string]any) (bool, bool) {
|
||||||
|
data, _ := result["data"].([]any)
|
||||||
|
for _, raw := range data {
|
||||||
|
item, _ := raw.(map[string]any)
|
||||||
|
if len(item) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if value, ok := boolishOptional(firstPresentValue(item, "has_audio", "hasAudio")); ok {
|
||||||
|
return value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
func requestHasReferenceVideo(body map[string]any) bool {
|
func requestHasReferenceVideo(body map[string]any) bool {
|
||||||
if hasNonEmptyArray(body["video_list"]) || hasNonEmptyArray(body["videoList"]) {
|
if hasNonEmptyArray(body["video_list"]) || hasNonEmptyArray(body["videoList"]) {
|
||||||
return true
|
return true
|
||||||
@ -362,8 +419,8 @@ func requestHasReferenceVideo(body map[string]any) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestHasVoiceID(body map[string]any) bool {
|
func requestHasVoiceID(body map[string]any, audioEnabled bool) bool {
|
||||||
return boolishValue(body["audio"]) && firstNonEmptyStringValue(body, "voice_id", "voiceId") != ""
|
return audioEnabled && firstNonEmptyStringValue(body, "voice_id", "voiceId") != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func boolWeightKey(value bool) string {
|
func boolWeightKey(value bool) string {
|
||||||
@ -374,25 +431,38 @@ func boolWeightKey(value bool) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func boolishValue(value any) bool {
|
func boolishValue(value any) bool {
|
||||||
|
result, _ := boolishOptional(value)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolishOptional(value any) (bool, bool) {
|
||||||
switch typed := value.(type) {
|
switch typed := value.(type) {
|
||||||
case bool:
|
case bool:
|
||||||
return typed
|
return typed, true
|
||||||
case string:
|
case string:
|
||||||
switch strings.ToLower(strings.TrimSpace(typed)) {
|
switch strings.ToLower(strings.TrimSpace(typed)) {
|
||||||
case "true", "1", "yes", "on":
|
case "true", "1", "yes", "on":
|
||||||
return true
|
return true, true
|
||||||
default:
|
case "false", "0", "no", "off":
|
||||||
return false
|
return false, true
|
||||||
}
|
}
|
||||||
case int:
|
case int:
|
||||||
return typed != 0
|
return typed != 0, true
|
||||||
case int64:
|
case int64:
|
||||||
return typed != 0
|
return typed != 0, true
|
||||||
case float64:
|
case float64:
|
||||||
return typed != 0
|
return typed != 0, true
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstPresentValue(record map[string]any, keys ...string) any {
|
||||||
|
for _, key := range keys {
|
||||||
|
if value, ok := record[key]; ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasNonEmptyArray(value any) bool {
|
func hasNonEmptyArray(value any) bool {
|
||||||
|
|||||||
@ -76,6 +76,100 @@ func TestVideoBillingEstimateUsesFiveSecondUnitsAndDynamicWeights(t *testing.T)
|
|||||||
if got, want := line["quantity"], 3; got != want {
|
if got, want := line["quantity"], 3; got != want {
|
||||||
t.Fatalf("video quantity = %v, want %v", got, want)
|
t.Fatalf("video quantity = %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
if got, want := line["durationSource"], "preprocessed_request"; got != want {
|
||||||
|
t.Fatalf("video duration source = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := line["audioSource"], "preprocessed_request"; got != want {
|
||||||
|
t.Fatalf("video audio source = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVideoBillingPrefersGeneratedDuration(t *testing.T) {
|
||||||
|
service := &Service{}
|
||||||
|
candidate := store.RuntimeModelCandidate{
|
||||||
|
ModelName: "video-model",
|
||||||
|
BaseBillingConfig: map[string]any{
|
||||||
|
"video": map[string]any{"basePrice": 100},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
items := service.billings(context.Background(), nil, "videos.generations", map[string]any{
|
||||||
|
"duration": 12,
|
||||||
|
"resolution": "720p",
|
||||||
|
}, candidate, clients.Response{
|
||||||
|
Result: map[string]any{
|
||||||
|
"data": []any{map[string]any{"type": "video", "duration": 6.6}},
|
||||||
|
},
|
||||||
|
}, false)
|
||||||
|
|
||||||
|
line := firstBillingLine(t, items)
|
||||||
|
if got, want := floatFromAny(line["durationSeconds"]), 7.0; got != want {
|
||||||
|
t.Fatalf("video generated duration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := floatFromAny(line["durationUnitCount"]), 2.0; got != want {
|
||||||
|
t.Fatalf("video generated duration units = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := floatFromAny(line["amount"]), 200.0; got != want {
|
||||||
|
t.Fatalf("video generated duration amount = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := line["durationSource"], "generated_video"; got != want {
|
||||||
|
t.Fatalf("video duration source = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVideoBillingPrefersGeneratedAudio(t *testing.T) {
|
||||||
|
service := &Service{}
|
||||||
|
candidate := store.RuntimeModelCandidate{
|
||||||
|
ModelName: "video-model",
|
||||||
|
BaseBillingConfig: map[string]any{
|
||||||
|
"video": map[string]any{
|
||||||
|
"basePrice": 100,
|
||||||
|
"audioWeights": map[string]any{"true": 2, "false": 0.5},
|
||||||
|
"voiceSpecifiedWeights": map[string]any{"true": 4},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
items := service.billings(context.Background(), nil, "videos.generations", map[string]any{
|
||||||
|
"audio": false,
|
||||||
|
"duration": 5,
|
||||||
|
}, candidate, clients.Response{
|
||||||
|
Result: map[string]any{
|
||||||
|
"data": []any{map[string]any{"type": "video", "has_audio": true}},
|
||||||
|
},
|
||||||
|
}, false)
|
||||||
|
|
||||||
|
line := firstBillingLine(t, items)
|
||||||
|
if got, want := floatFromAny(line["amount"]), 200.0; got != want {
|
||||||
|
t.Fatalf("video generated audio amount = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := line["audio"], true; got != want {
|
||||||
|
t.Fatalf("video generated audio = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := line["audioSource"], "generated_video"; got != want {
|
||||||
|
t.Fatalf("video audio source = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
items = service.billings(context.Background(), nil, "videos.generations", map[string]any{
|
||||||
|
"audio": true,
|
||||||
|
"duration": 5,
|
||||||
|
"voice_id": "voice-a",
|
||||||
|
}, candidate, clients.Response{
|
||||||
|
Result: map[string]any{
|
||||||
|
"data": []any{map[string]any{"type": "video", "hasAudio": false}},
|
||||||
|
},
|
||||||
|
}, false)
|
||||||
|
|
||||||
|
line = firstBillingLine(t, items)
|
||||||
|
if got, want := floatFromAny(line["amount"]), 50.0; got != want {
|
||||||
|
t.Fatalf("video generated no-audio amount = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := line["audio"], false; got != want {
|
||||||
|
t.Fatalf("video generated no-audio = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := line["audioSource"], "generated_video"; got != want {
|
||||||
|
t.Fatalf("video no-audio source = %v, want %v", got, want)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVideoBillingEstimateSupportsServerMainStyleDynamicKeys(t *testing.T) {
|
func TestVideoBillingEstimateSupportsServerMainStyleDynamicKeys(t *testing.T) {
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -210,6 +211,12 @@ func failureMetrics(err error, simulated bool) (string, map[string]any, time.Tim
|
|||||||
metrics["error"] = err.Error()
|
metrics["error"] = err.Error()
|
||||||
metrics["errorCategory"] = info.Category
|
metrics["errorCategory"] = info.Category
|
||||||
metrics["retryable"] = retryable
|
metrics["retryable"] = retryable
|
||||||
|
if detail := rateLimitFailureDetail(err); len(detail) > 0 {
|
||||||
|
metrics["rateLimit"] = detail
|
||||||
|
}
|
||||||
|
if detail := store.ModelCandidateErrorDetails(err); len(detail) > 0 {
|
||||||
|
metrics["modelCandidate"] = detail
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if meta.StatusCode > 0 {
|
if meta.StatusCode > 0 {
|
||||||
metrics["statusCode"] = meta.StatusCode
|
metrics["statusCode"] = meta.StatusCode
|
||||||
@ -226,6 +233,64 @@ func failureMetrics(err error, simulated bool) (string, map[string]any, time.Tim
|
|||||||
return meta.RequestID, metrics, meta.ResponseStartedAt, meta.ResponseFinishedAt, meta.ResponseDurationMS
|
return meta.RequestID, metrics, meta.ResponseStartedAt, meta.ResponseFinishedAt, meta.ResponseDurationMS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildFailureResult(code string, message string, requestID string, err error) map[string]any {
|
||||||
|
errorPayload := map[string]any{
|
||||||
|
"code": code,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
if requestID != "" {
|
||||||
|
errorPayload["requestId"] = requestID
|
||||||
|
}
|
||||||
|
if detail := rateLimitFailureDetail(err); len(detail) > 0 {
|
||||||
|
errorPayload["rateLimit"] = detail
|
||||||
|
}
|
||||||
|
if detail := store.ModelCandidateErrorDetails(err); len(detail) > 0 {
|
||||||
|
errorPayload["modelCandidate"] = detail
|
||||||
|
}
|
||||||
|
return map[string]any{"error": errorPayload}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rateLimitFailureDetail(err error) map[string]any {
|
||||||
|
var limitErr *store.RateLimitExceededError
|
||||||
|
if !errors.As(err, &limitErr) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
detail := map[string]any{
|
||||||
|
"scopeType": limitErr.ScopeType,
|
||||||
|
"scopeKey": limitErr.ScopeKey,
|
||||||
|
"scopeName": limitErr.ScopeName,
|
||||||
|
"metric": limitErr.Metric,
|
||||||
|
"limit": limitErr.Limit,
|
||||||
|
"amount": limitErr.Amount,
|
||||||
|
"current": limitErr.Current,
|
||||||
|
"used": limitErr.Used,
|
||||||
|
"reserved": limitErr.Reserved,
|
||||||
|
"projected": limitErr.Projected,
|
||||||
|
"windowSeconds": limitErr.WindowSeconds,
|
||||||
|
"retryable": limitErr.Retryable,
|
||||||
|
"exceeded": map[string]any{
|
||||||
|
"metric": limitErr.Metric,
|
||||||
|
"current": limitErr.Current,
|
||||||
|
"amount": limitErr.Amount,
|
||||||
|
"projected": limitErr.Projected,
|
||||||
|
"limit": limitErr.Limit,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if limitErr.RetryAfter > 0 {
|
||||||
|
detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds()
|
||||||
|
}
|
||||||
|
if !limitErr.ResetAt.IsZero() {
|
||||||
|
detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano)
|
||||||
|
}
|
||||||
|
if len(limitErr.ScopeMetadata) > 0 {
|
||||||
|
detail["scopeMetadata"] = limitErr.ScopeMetadata
|
||||||
|
}
|
||||||
|
if len(limitErr.Policy) > 0 {
|
||||||
|
detail["rateLimitPolicy"] = limitErr.Policy
|
||||||
|
}
|
||||||
|
return detail
|
||||||
|
}
|
||||||
|
|
||||||
func mergeMetrics(values ...map[string]any) map[string]any {
|
func mergeMetrics(values ...map[string]any) map[string]any {
|
||||||
out := map[string]any{}
|
out := map[string]any{}
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/riverqueue/river"
|
"github.com/riverqueue/river"
|
||||||
@ -22,6 +23,7 @@ type Service struct {
|
|||||||
store *store.Store
|
store *store.Store
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
clients map[string]clients.Client
|
clients map[string]clients.Client
|
||||||
|
scriptExecutor *scriptengine.Executor
|
||||||
httpClients *httpClientCache
|
httpClients *httpClientCache
|
||||||
riverClient *river.Client[pgx.Tx]
|
riverClient *river.Client[pgx.Tx]
|
||||||
}
|
}
|
||||||
@ -47,14 +49,28 @@ func (e *TaskQueuedError) Is(target error) bool {
|
|||||||
|
|
||||||
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
|
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
|
||||||
httpClients := newHTTPClientCache()
|
httpClients := newHTTPClientCache()
|
||||||
|
scriptExecutor := &scriptengine.Executor{Logger: logger}
|
||||||
return &Service{
|
return &Service{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
store: db,
|
store: db,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
scriptExecutor: scriptExecutor,
|
||||||
clients: map[string]clients.Client{
|
clients: map[string]clients.Client{
|
||||||
"openai": clients.OpenAIClient{HTTPClient: httpClients.none},
|
"openai": clients.OpenAIClient{HTTPClient: httpClients.none},
|
||||||
|
"aliyun-bailian": clients.AliyunBailianClient{HTTPClient: httpClients.none},
|
||||||
|
"blackforest": clients.BlackforestClient{HTTPClient: httpClients.none},
|
||||||
"gemini": clients.GeminiClient{HTTPClient: httpClients.none},
|
"gemini": clients.GeminiClient{HTTPClient: httpClients.none},
|
||||||
|
"jimeng": clients.JimengClient{HTTPClient: httpClients.none},
|
||||||
|
"midjourney": clients.MidjourneyClient{HTTPClient: httpClients.none},
|
||||||
|
"minimax": clients.MinimaxClient{HTTPClient: httpClients.none},
|
||||||
|
"newapi": clients.NewAPIClient{HTTPClient: httpClients.none},
|
||||||
|
"tencent-hunyuan-image": clients.HunyuanImageClient{HTTPClient: httpClients.none},
|
||||||
|
"tencent-hunyuan-video": clients.HunyuanVideoClient{HTTPClient: httpClients.none},
|
||||||
|
"vidu": clients.ViduClient{HTTPClient: httpClients.none},
|
||||||
"volces": clients.VolcesClient{HTTPClient: httpClients.none},
|
"volces": clients.VolcesClient{HTTPClient: httpClients.none},
|
||||||
|
"keling": clients.KelingClient{HTTPClient: httpClients.none},
|
||||||
|
"kling": clients.KelingClient{HTTPClient: httpClients.none},
|
||||||
|
"universal": clients.UniversalClient{HTTPClient: httpClients.none, ScriptExecutor: scriptExecutor},
|
||||||
"simulation": clients.SimulationClient{},
|
"simulation": clients.SimulationClient{},
|
||||||
},
|
},
|
||||||
httpClients: httpClients,
|
httpClients: httpClients,
|
||||||
@ -82,6 +98,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := validateRequest(task.Kind, body); err != nil {
|
if err := validateRequest(task.Kind, body); err != nil {
|
||||||
|
s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
|
Task: task,
|
||||||
|
Body: body,
|
||||||
|
AttemptNo: task.AttemptCount + 1,
|
||||||
|
Code: "bad_request",
|
||||||
|
Cause: err,
|
||||||
|
Simulated: task.RunMode == "simulation",
|
||||||
|
Scope: "request_validation",
|
||||||
|
Reason: "request_validation_failed",
|
||||||
|
ModelType: modelType,
|
||||||
|
})
|
||||||
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
|
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
|
||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
@ -90,25 +117,77 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
}
|
}
|
||||||
candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
|
candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
|
Task: task,
|
||||||
|
Body: body,
|
||||||
|
AttemptNo: task.AttemptCount + 1,
|
||||||
|
Code: store.ModelCandidateErrorCode(err),
|
||||||
|
Cause: err,
|
||||||
|
Simulated: task.RunMode == "simulation",
|
||||||
|
Scope: "candidate_selection",
|
||||||
|
Reason: "candidate_selection_failed",
|
||||||
|
ModelType: modelType,
|
||||||
|
})
|
||||||
failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), task.RunMode == "simulation", err)
|
failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), task.RunMode == "simulation", err)
|
||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
}
|
}
|
||||||
return Result{Task: failed, Output: failed.Result}, err
|
return Result{Task: failed, Output: failed.Result}, err
|
||||||
}
|
}
|
||||||
|
var candidateFilterSummary map[string]any
|
||||||
|
candidates, candidateFilterSummary, err = filterRuntimeCandidatesByRequest(task.Kind, task.Model, modelType, body, candidates)
|
||||||
|
if err != nil {
|
||||||
|
candidateFilterMetrics := candidateCapabilityFilterMetrics(candidateFilterSummary)
|
||||||
|
s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
|
Task: task,
|
||||||
|
Body: body,
|
||||||
|
AttemptNo: task.AttemptCount + 1,
|
||||||
|
Code: store.ModelCandidateErrorCode(err),
|
||||||
|
Cause: err,
|
||||||
|
Simulated: task.RunMode == "simulation",
|
||||||
|
Scope: "candidate_request_filter",
|
||||||
|
Reason: store.ModelCandidateErrorCode(err),
|
||||||
|
ExtraMetrics: []map[string]any{candidateFilterMetrics},
|
||||||
|
ModelType: modelType,
|
||||||
|
})
|
||||||
|
failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), task.RunMode == "simulation", err, candidateFilterMetrics)
|
||||||
|
if finishErr != nil {
|
||||||
|
return Result{}, finishErr
|
||||||
|
}
|
||||||
|
return Result{Task: failed, Output: failed.Result}, err
|
||||||
|
}
|
||||||
firstCandidateBody := body
|
firstCandidateBody := body
|
||||||
normalizedModelType := modelType
|
normalizedModelType := modelType
|
||||||
|
attemptNo := task.AttemptCount
|
||||||
var firstPreprocessing parameterPreprocessingLog
|
var firstPreprocessing parameterPreprocessingLog
|
||||||
|
var walletReservations []store.WalletBillingReservation
|
||||||
|
walletReservationFinalized := false
|
||||||
|
defer func() {
|
||||||
|
if !walletReservationFinalized && len(walletReservations) > 0 {
|
||||||
|
_ = s.store.ReleaseTaskBillingReservations(context.WithoutCancel(ctx), walletReservations, "task_not_settled")
|
||||||
|
}
|
||||||
|
}()
|
||||||
if len(candidates) > 0 {
|
if len(candidates) > 0 {
|
||||||
preprocessing := preprocessRequestWithLog(task.Kind, body, candidates[0])
|
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0])
|
||||||
firstCandidateBody = preprocessing.Body
|
firstCandidateBody = preprocessing.Body
|
||||||
firstPreprocessing = preprocessing.Log
|
firstPreprocessing = preprocessing.Log
|
||||||
normalizedModelType = candidates[0].ModelType
|
normalizedModelType = candidates[0].ModelType
|
||||||
if preprocessing.Err != nil {
|
if preprocessing.Err != nil {
|
||||||
clientErr := parameterPreprocessClientError(preprocessing.Err)
|
clientErr := parameterPreprocessClientError(preprocessing.Err)
|
||||||
if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil {
|
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
return Result{}, logErr
|
Task: task,
|
||||||
}
|
Body: firstCandidateBody,
|
||||||
|
Candidate: &candidates[0],
|
||||||
|
AttemptNo: attemptNo + 1,
|
||||||
|
Code: clients.ErrorCode(clientErr),
|
||||||
|
Cause: clientErr,
|
||||||
|
Simulated: task.RunMode == "simulation",
|
||||||
|
Scope: "parameter_preprocessing",
|
||||||
|
Reason: "parameter_preprocessing_failed",
|
||||||
|
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(firstPreprocessing)},
|
||||||
|
Preprocessing: &firstPreprocessing,
|
||||||
|
ModelType: normalizedModelType,
|
||||||
|
})
|
||||||
failed, finishErr := s.failTask(ctx, task.ID, clients.ErrorCode(clientErr), clientErr.Error(), task.RunMode == "simulation", clientErr, parameterPreprocessingMetrics(firstPreprocessing))
|
failed, finishErr := s.failTask(ctx, task.ID, clients.ErrorCode(clientErr), clientErr.Error(), task.RunMode == "simulation", clientErr, parameterPreprocessingMetrics(firstPreprocessing))
|
||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
@ -119,18 +198,31 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
return Result{}, err
|
return Result{}, err
|
||||||
}
|
}
|
||||||
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
||||||
if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
|
var reserveErr error
|
||||||
if errors.Is(err, store.ErrInsufficientWalletBalance) {
|
walletReservations, reserveErr = s.store.ReserveTaskBilling(ctx, task, user, estimatedBillings)
|
||||||
if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil {
|
if reserveErr != nil {
|
||||||
return Result{}, logErr
|
if errors.Is(reserveErr, store.ErrInsufficientWalletBalance) {
|
||||||
}
|
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err, parameterPreprocessingMetrics(firstPreprocessing))
|
Task: task,
|
||||||
|
Body: firstCandidateBody,
|
||||||
|
Candidate: &candidates[0],
|
||||||
|
AttemptNo: attemptNo + 1,
|
||||||
|
Code: "insufficient_balance",
|
||||||
|
Cause: reserveErr,
|
||||||
|
Simulated: task.RunMode == "simulation",
|
||||||
|
Scope: "wallet_balance",
|
||||||
|
Reason: "wallet_balance_check_failed",
|
||||||
|
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(firstPreprocessing)},
|
||||||
|
Preprocessing: &firstPreprocessing,
|
||||||
|
ModelType: normalizedModelType,
|
||||||
|
})
|
||||||
|
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", reserveErr.Error(), task.RunMode == "simulation", reserveErr, parameterPreprocessingMetrics(firstPreprocessing))
|
||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
}
|
}
|
||||||
return Result{Task: failed, Output: failed.Result}, err
|
return Result{Task: failed, Output: failed.Result}, reserveErr
|
||||||
}
|
}
|
||||||
return Result{}, err
|
return Result{}, reserveErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil {
|
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil {
|
||||||
@ -143,7 +235,6 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
}
|
}
|
||||||
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
|
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
|
||||||
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
|
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
|
||||||
attemptNo := task.AttemptCount
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
var lastCandidate store.RuntimeModelCandidate
|
var lastCandidate store.RuntimeModelCandidate
|
||||||
var lastPreprocessing *parameterPreprocessingLog
|
var lastPreprocessing *parameterPreprocessingLog
|
||||||
@ -157,11 +248,25 @@ candidatesLoop:
|
|||||||
var candidateErr error
|
var candidateErr error
|
||||||
for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
|
for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
|
||||||
nextAttemptNo := attemptNo + 1
|
nextAttemptNo := attemptNo + 1
|
||||||
preprocessing := preprocessRequestWithLog(task.Kind, body, candidate)
|
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidate)
|
||||||
preprocessingLog := preprocessing.Log
|
preprocessingLog := preprocessing.Log
|
||||||
lastPreprocessing = &preprocessingLog
|
lastPreprocessing = &preprocessingLog
|
||||||
if preprocessing.Err != nil {
|
if preprocessing.Err != nil {
|
||||||
lastErr = parameterPreprocessClientError(preprocessing.Err)
|
lastErr = parameterPreprocessClientError(preprocessing.Err)
|
||||||
|
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
|
Task: task,
|
||||||
|
Body: preprocessing.Body,
|
||||||
|
Candidate: &candidate,
|
||||||
|
AttemptNo: nextAttemptNo,
|
||||||
|
Code: clients.ErrorCode(lastErr),
|
||||||
|
Cause: lastErr,
|
||||||
|
Simulated: isSimulation(task, candidate),
|
||||||
|
Scope: "parameter_preprocessing",
|
||||||
|
Reason: "parameter_preprocessing_failed",
|
||||||
|
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(preprocessingLog)},
|
||||||
|
Preprocessing: &preprocessingLog,
|
||||||
|
ModelType: candidate.ModelType,
|
||||||
|
})
|
||||||
break candidatesLoop
|
break candidatesLoop
|
||||||
}
|
}
|
||||||
candidateBody := preprocessing.Body
|
candidateBody := preprocessing.Body
|
||||||
@ -170,6 +275,7 @@ candidatesLoop:
|
|||||||
attemptNo = nextAttemptNo
|
attemptNo = nextAttemptNo
|
||||||
billings := s.billings(ctx, user, task.Kind, candidateBody, candidate, response, isSimulation(task, candidate))
|
billings := s.billings(ctx, user, task.Kind, candidateBody, candidate, response, isSimulation(task, candidate))
|
||||||
record := buildSuccessRecord(task, user, candidateBody, candidate, response, billings, isSimulation(task, candidate))
|
record := buildSuccessRecord(task, user, candidateBody, candidate, response, billings, isSimulation(task, candidate))
|
||||||
|
record.Metrics = mergeMetrics(record.Metrics, candidateCapabilityFilterMetrics(candidateFilterSummary))
|
||||||
record.Metrics = mergeMetrics(record.Metrics, parameterPreprocessingMetrics(preprocessing.Log))
|
record.Metrics = mergeMetrics(record.Metrics, parameterPreprocessingMetrics(preprocessing.Log))
|
||||||
record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics)
|
record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics)
|
||||||
finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
|
finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
|
||||||
@ -189,9 +295,18 @@ candidatesLoop:
|
|||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
}
|
}
|
||||||
|
if finished.FinalChargeAmount > 0 {
|
||||||
|
walletReservationFinalized = true
|
||||||
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
||||||
return Result{}, settleErr
|
return Result{}, settleErr
|
||||||
}
|
}
|
||||||
|
} else if len(walletReservations) > 0 {
|
||||||
|
if releaseErr := s.store.ReleaseTaskBillingReservations(ctx, walletReservations, "task_billing_zero"); releaseErr != nil {
|
||||||
|
return Result{}, releaseErr
|
||||||
|
}
|
||||||
|
walletReservationFinalized = true
|
||||||
|
}
|
||||||
|
walletReservationFinalized = true
|
||||||
if finished.FinalChargeAmount > 0 {
|
if finished.FinalChargeAmount > 0 {
|
||||||
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
|
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
|
||||||
"amount": finished.FinalChargeAmount,
|
"amount": finished.FinalChargeAmount,
|
||||||
@ -222,6 +337,19 @@ candidatesLoop:
|
|||||||
}
|
}
|
||||||
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
|
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
|
||||||
}
|
}
|
||||||
|
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
|
Task: task,
|
||||||
|
Body: candidateBody,
|
||||||
|
Candidate: &candidate,
|
||||||
|
AttemptNo: nextAttemptNo,
|
||||||
|
Code: clients.ErrorCode(err),
|
||||||
|
Cause: err,
|
||||||
|
Simulated: isSimulation(task, candidate),
|
||||||
|
Scope: "rate_limit",
|
||||||
|
Reason: "local_rate_limit_blocked",
|
||||||
|
ExtraMetrics: []map[string]any{parameterPreprocessingMetrics(preprocessing.Log)},
|
||||||
|
ModelType: candidate.ModelType,
|
||||||
|
})
|
||||||
break candidatesLoop
|
break candidatesLoop
|
||||||
}
|
}
|
||||||
attemptNo = nextAttemptNo
|
attemptNo = nextAttemptNo
|
||||||
@ -520,6 +648,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
|||||||
return clients.Response{}, err
|
return clients.Response{}, err
|
||||||
}
|
}
|
||||||
response.Result = uploadedResult
|
response.Result = uploadedResult
|
||||||
|
response.Result = s.enrichGeneratedVideoMetadata(ctx, task.Kind, response.Result)
|
||||||
for _, progress := range response.Progress {
|
for _, progress := range response.Progress {
|
||||||
if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
|
if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
|
||||||
return clients.Response{}, fmt.Errorf("emit task progress: %w", err)
|
return clients.Response{}, fmt.Errorf("emit task progress: %w", err)
|
||||||
@ -584,6 +713,11 @@ func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated boo
|
|||||||
if key == "" {
|
if key == "" {
|
||||||
key = strings.ToLower(strings.TrimSpace(candidate.Provider))
|
key = strings.ToLower(strings.TrimSpace(candidate.Provider))
|
||||||
}
|
}
|
||||||
|
provider := strings.ToLower(strings.TrimSpace(candidate.Provider))
|
||||||
|
baseURL := strings.ToLower(strings.TrimSpace(candidate.BaseURL))
|
||||||
|
if key == "google-gemini" || provider == "gemini" || provider == "google-gemini" || provider == "gemini-openai" || strings.Contains(baseURL, "generativelanguage.googleapis.com") {
|
||||||
|
key = "gemini"
|
||||||
|
}
|
||||||
if client, ok := s.clients[key]; ok {
|
if client, ok := s.clients[key]; ok {
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@ -601,6 +735,7 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
|
|||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
Code: code,
|
Code: code,
|
||||||
Message: message,
|
Message: message,
|
||||||
|
Result: buildFailureResult(code, message, requestID, cause),
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Metrics: metrics,
|
Metrics: metrics,
|
||||||
ResponseStartedAt: responseStartedAt,
|
ResponseStartedAt: responseStartedAt,
|
||||||
@ -616,6 +751,110 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
|
|||||||
return failed, nil
|
return failed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type failedAttemptRecord struct {
|
||||||
|
Task store.GatewayTask
|
||||||
|
Body map[string]any
|
||||||
|
Candidate *store.RuntimeModelCandidate
|
||||||
|
AttemptNo int
|
||||||
|
Code string
|
||||||
|
Cause error
|
||||||
|
Simulated bool
|
||||||
|
Scope string
|
||||||
|
Reason string
|
||||||
|
ExtraMetrics []map[string]any
|
||||||
|
Preprocessing *parameterPreprocessingLog
|
||||||
|
ModelType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) recordFailedAttempt(ctx context.Context, input failedAttemptRecord) int {
|
||||||
|
attemptNo := input.AttemptNo
|
||||||
|
if attemptNo <= 0 {
|
||||||
|
attemptNo = input.Task.AttemptCount + 1
|
||||||
|
}
|
||||||
|
code := firstNonEmptyString(input.Code, clients.ErrorCode(input.Cause))
|
||||||
|
message := ""
|
||||||
|
if input.Cause != nil {
|
||||||
|
message = input.Cause.Error()
|
||||||
|
}
|
||||||
|
retryable := clients.IsRetryable(input.Cause)
|
||||||
|
requestID, failure, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(input.Cause, input.Simulated)
|
||||||
|
scope := firstNonEmptyString(input.Scope, "pre_provider")
|
||||||
|
reason := firstNonEmptyString(input.Reason, "pre_provider_failed")
|
||||||
|
trace := failureTraceEntryWithReason(input.Cause, retryable, scope, reason)
|
||||||
|
statusCode := clients.ErrorResponseMetadata(input.Cause).StatusCode
|
||||||
|
category := failureCategory(strings.ToLower(strings.TrimSpace(code)), statusCode, message)
|
||||||
|
if code != "" {
|
||||||
|
failure["errorCode"] = code
|
||||||
|
trace["errorCode"] = code
|
||||||
|
}
|
||||||
|
if category != "" {
|
||||||
|
failure["errorCategory"] = category
|
||||||
|
trace["category"] = category
|
||||||
|
}
|
||||||
|
failure["failureScope"] = scope
|
||||||
|
failure["failureReason"] = reason
|
||||||
|
failure["trace"] = []any{trace}
|
||||||
|
|
||||||
|
baseMetrics := map[string]any{
|
||||||
|
"attempt": attemptNo,
|
||||||
|
"kind": input.Task.Kind,
|
||||||
|
"runMode": input.Task.RunMode,
|
||||||
|
"requestedModel": input.Task.Model,
|
||||||
|
"simulated": input.Simulated,
|
||||||
|
}
|
||||||
|
if input.ModelType != "" {
|
||||||
|
baseMetrics["modelType"] = input.ModelType
|
||||||
|
}
|
||||||
|
var platformID, platformModelID, clientID, queueKey string
|
||||||
|
if input.Candidate != nil {
|
||||||
|
baseMetrics = attemptMetrics(*input.Candidate, attemptNo, input.Simulated)
|
||||||
|
baseMetrics["kind"] = input.Task.Kind
|
||||||
|
baseMetrics["runMode"] = input.Task.RunMode
|
||||||
|
baseMetrics["requestedModel"] = input.Task.Model
|
||||||
|
platformID = input.Candidate.PlatformID
|
||||||
|
platformModelID = input.Candidate.PlatformModelID
|
||||||
|
clientID = input.Candidate.ClientID
|
||||||
|
queueKey = input.Candidate.QueueKey
|
||||||
|
}
|
||||||
|
metrics := mergeMetrics(append([]map[string]any{baseMetrics, failure}, input.ExtraMetrics...)...)
|
||||||
|
attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
|
||||||
|
TaskID: input.Task.ID,
|
||||||
|
AttemptNo: attemptNo,
|
||||||
|
PlatformID: platformID,
|
||||||
|
PlatformModelID: platformModelID,
|
||||||
|
ClientID: clientID,
|
||||||
|
QueueKey: queueKey,
|
||||||
|
Status: "running",
|
||||||
|
Simulated: input.Simulated,
|
||||||
|
RequestSnapshot: input.Body,
|
||||||
|
Metrics: metrics,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("record failed task attempt failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
|
||||||
|
return attemptNo
|
||||||
|
}
|
||||||
|
if input.Preprocessing != nil && input.Candidate != nil {
|
||||||
|
if err := s.recordTaskParameterPreprocessing(ctx, input.Task.ID, attemptID, attemptNo, *input.Candidate, *input.Preprocessing); err != nil {
|
||||||
|
s.logger.Warn("record failed attempt parameter preprocessing failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
||||||
|
AttemptID: attemptID,
|
||||||
|
Status: "failed",
|
||||||
|
Retryable: retryable,
|
||||||
|
RequestID: requestID,
|
||||||
|
Metrics: metrics,
|
||||||
|
ResponseStartedAt: responseStartedAt,
|
||||||
|
ResponseFinishedAt: responseFinishedAt,
|
||||||
|
ResponseDurationMS: responseDurationMS,
|
||||||
|
ErrorCode: code,
|
||||||
|
ErrorMessage: message,
|
||||||
|
}); err != nil {
|
||||||
|
s.logger.Warn("finish failed task attempt failed", "taskID", input.Task.ID, "attempt", attemptNo, "error", err)
|
||||||
|
}
|
||||||
|
return attemptNo
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error, candidate store.RuntimeModelCandidate) (store.GatewayTask, time.Duration, error) {
|
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error, candidate store.RuntimeModelCandidate) (store.GatewayTask, time.Duration, error) {
|
||||||
delay := localRateLimitRetryAfter(cause)
|
delay := localRateLimitRetryAfter(cause)
|
||||||
if delay <= 0 {
|
if delay <= 0 {
|
||||||
@ -888,8 +1127,13 @@ func parameterPreprocessClientError(err error) *clients.ClientError {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
code := "invalid_parameter"
|
||||||
|
var coded interface{ ErrorCode() string }
|
||||||
|
if errors.As(err, &coded) && strings.TrimSpace(coded.ErrorCode()) != "" {
|
||||||
|
code = coded.ErrorCode()
|
||||||
|
}
|
||||||
return &clients.ClientError{
|
return &clients.ClientError{
|
||||||
Code: "invalid_parameter",
|
Code: code,
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
StatusCode: 400,
|
StatusCode: 400,
|
||||||
Retryable: false,
|
Retryable: false,
|
||||||
|
|||||||
34
apps/api/internal/runner/service_test.go
Normal file
34
apps/api/internal/runner/service_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type namedClient string
|
||||||
|
|
||||||
|
func (namedClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
||||||
|
return clients.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientForMapsGoogleGeminiSpecToGeminiClient(t *testing.T) {
|
||||||
|
service := &Service{clients: map[string]clients.Client{
|
||||||
|
"gemini": namedClient("gemini"),
|
||||||
|
"openai": namedClient("openai"),
|
||||||
|
}}
|
||||||
|
|
||||||
|
candidates := []store.RuntimeModelCandidate{
|
||||||
|
{SpecType: "google-gemini"},
|
||||||
|
{SpecType: "openai", Provider: "gemini-openai"},
|
||||||
|
{SpecType: "openai", BaseURL: "https://generativelanguage.googleapis.com/v1beta/openai"},
|
||||||
|
}
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
client := service.clientFor(candidate, false)
|
||||||
|
if client != namedClient("gemini") {
|
||||||
|
t.Fatalf("Gemini candidate should use gemini client, candidate=%+v got %T %[2]v", candidate, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -7,8 +7,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func failureTraceEntry(err error, retryable bool) map[string]any {
|
func failureTraceEntry(err error, retryable bool) map[string]any {
|
||||||
|
return failureTraceEntryWithReason(err, retryable, "client", "client_call_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func failureTraceEntryWithReason(err error, retryable bool, scope string, reason string) map[string]any {
|
||||||
info := failureInfoFromError(err)
|
info := failureInfoFromError(err)
|
||||||
entry := policyTraceEntry("failure", "client", "failed", "client_call_failed", policyRuleMatch{}, info)
|
entry := policyTraceEntry("failure", scope, "failed", reason, policyRuleMatch{}, info)
|
||||||
entry["retryable"] = retryable
|
entry["retryable"] = retryable
|
||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
|
|||||||
144
apps/api/internal/runner/video_duration.go
Normal file
144
apps/api/internal/runner/video_duration.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const generatedVideoMetadataProbeTimeout = 8 * time.Second
|
||||||
|
|
||||||
|
type generatedVideoMetadata struct {
|
||||||
|
Duration float64
|
||||||
|
HasAudio bool
|
||||||
|
HasAudioKnown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ffprobeVideoMetadata struct {
|
||||||
|
Format struct {
|
||||||
|
Duration string `json:"duration"`
|
||||||
|
} `json:"format"`
|
||||||
|
Streams []struct {
|
||||||
|
CodecType string `json:"codec_type"`
|
||||||
|
} `json:"streams"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) enrichGeneratedVideoMetadata(ctx context.Context, taskKind string, result map[string]any) map[string]any {
|
||||||
|
if taskKind != "videos.generations" {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
data, _ := result["data"].([]any)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
for _, raw := range data {
|
||||||
|
item, _ := raw.(map[string]any)
|
||||||
|
if len(item) == 0 || !isGeneratedVideoItem(item) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
needsDuration := floatFromAny(item["duration"]) <= 0
|
||||||
|
_, hasAudioMetadata := boolishOptional(firstPresentValue(item, "has_audio", "hasAudio"))
|
||||||
|
if !needsDuration && hasAudioMetadata {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
urlValue := firstNonEmptyStringValue(item, "video_url", "videoUrl", "url")
|
||||||
|
if urlValue == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
metadata, err := s.probeVideoMetadata(ctx, urlValue)
|
||||||
|
if err != nil {
|
||||||
|
if s.logger != nil {
|
||||||
|
s.logger.Debug("probe generated video metadata failed", "url", trimForLog(urlValue), "error", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if needsDuration && metadata.Duration > 0 {
|
||||||
|
item["duration"] = metadata.Duration
|
||||||
|
}
|
||||||
|
if !hasAudioMetadata && metadata.HasAudioKnown {
|
||||||
|
item["has_audio"] = metadata.HasAudio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGeneratedVideoItem(item map[string]any) bool {
|
||||||
|
itemType := strings.TrimSpace(stringFromAny(item["type"]))
|
||||||
|
if itemType == "video" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if firstNonEmptyStringValue(item, "video_url", "videoUrl") != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
urlValue := strings.ToLower(firstNonEmptyStringValue(item, "url"))
|
||||||
|
return strings.Contains(urlValue, ".mp4") ||
|
||||||
|
strings.Contains(urlValue, ".mov") ||
|
||||||
|
strings.Contains(urlValue, ".webm") ||
|
||||||
|
strings.Contains(urlValue, ".m3u8")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) probeVideoMetadata(ctx context.Context, rawURL string) (generatedVideoMetadata, error) {
|
||||||
|
if _, err := exec.LookPath("ffprobe"); err != nil {
|
||||||
|
return generatedVideoMetadata{}, err
|
||||||
|
}
|
||||||
|
probeURL := rawURL
|
||||||
|
if s != nil {
|
||||||
|
if resolved, err := s.generatedAssetFetchURL(rawURL); err == nil && strings.TrimSpace(resolved) != "" {
|
||||||
|
probeURL = resolved
|
||||||
|
}
|
||||||
|
}
|
||||||
|
probeCtx, cancel := context.WithTimeout(ctx, generatedVideoMetadataProbeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
cmd := exec.CommandContext(
|
||||||
|
probeCtx,
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "error",
|
||||||
|
"-show_entries", "format=duration:stream=codec_type",
|
||||||
|
"-of", "json",
|
||||||
|
probeURL,
|
||||||
|
)
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return generatedVideoMetadata{}, err
|
||||||
|
}
|
||||||
|
var probed ffprobeVideoMetadata
|
||||||
|
if err := json.Unmarshal(output, &probed); err != nil {
|
||||||
|
return generatedVideoMetadata{}, err
|
||||||
|
}
|
||||||
|
metadata := generatedVideoMetadata{}
|
||||||
|
if durationText := strings.TrimSpace(probed.Format.Duration); durationText != "" {
|
||||||
|
if duration, err := strconv.ParseFloat(durationText, 64); err == nil && duration > 0 && !math.IsNaN(duration) && !math.IsInf(duration, 0) {
|
||||||
|
rounded := math.Round(duration)
|
||||||
|
if rounded <= 0 {
|
||||||
|
rounded = 1
|
||||||
|
}
|
||||||
|
metadata.Duration = rounded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if probed.Streams != nil {
|
||||||
|
metadata.HasAudioKnown = true
|
||||||
|
for _, stream := range probed.Streams {
|
||||||
|
if strings.TrimSpace(stream.CodecType) == "audio" {
|
||||||
|
metadata.HasAudio = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if metadata.Duration <= 0 && !metadata.HasAudioKnown {
|
||||||
|
return metadata, fmt.Errorf("invalid video metadata: %q", trimForLog(string(output)))
|
||||||
|
}
|
||||||
|
return metadata, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimForLog(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if len(value) <= 120 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return value[:120] + "..."
|
||||||
|
}
|
||||||
@ -1,38 +0,0 @@
|
|||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *Service) ensureWalletBalance(ctx context.Context, user *auth.User, billings []any) error {
|
|
||||||
amounts := map[string]float64{}
|
|
||||||
for _, raw := range billings {
|
|
||||||
line, _ := raw.(map[string]any)
|
|
||||||
if line == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currency := strings.TrimSpace(stringFromAny(line["currency"]))
|
|
||||||
if currency == "" {
|
|
||||||
currency = "resource"
|
|
||||||
}
|
|
||||||
amounts[currency] = roundPrice(amounts[currency] + floatFromAny(line["amount"]))
|
|
||||||
}
|
|
||||||
for currency, amount := range amounts {
|
|
||||||
if amount <= 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
availability, err := s.store.WalletAvailability(ctx, user, currency, amount)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !availability.Enough {
|
|
||||||
return fmt.Errorf("%w: required %.6f %s, available %.6f", store.ErrInsufficientWalletBalance, amount, currency, availability.AvailableAmount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
202
apps/api/internal/runner/wallet_execute_test.go
Normal file
202
apps/api/internal/runner/wallet_execute_test.go
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type walletExecuteMockClient struct {
|
||||||
|
calls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
||||||
|
client.calls.Add(1)
|
||||||
|
return clients.Response{
|
||||||
|
Result: map[string]any{"mock": true},
|
||||||
|
RequestID: "mock-wallet-execute",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) {
|
||||||
|
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
||||||
|
if databaseURL == "" {
|
||||||
|
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db, err := store.Connect(ctx, databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect store: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(db.Close)
|
||||||
|
|
||||||
|
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{
|
||||||
|
TenantKey: "wallet-execute-" + suffix,
|
||||||
|
Name: "Wallet Execute Test " + suffix,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create tenant: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.DeleteTenant(context.Background(), tenant.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{
|
||||||
|
UserKey: "wallet-execute-user-" + suffix,
|
||||||
|
Username: "wallet_execute_" + suffix,
|
||||||
|
GatewayTenantID: tenant.ID,
|
||||||
|
TenantKey: tenant.TenantKey,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create gateway user: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{
|
||||||
|
Provider: "mock",
|
||||||
|
PlatformKey: "wallet-execute-mock-" + suffix,
|
||||||
|
Name: "Wallet Execute Mock " + suffix,
|
||||||
|
AuthType: "none",
|
||||||
|
Config: map[string]any{"specType": "mock"},
|
||||||
|
Status: "enabled",
|
||||||
|
Priority: 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create mock platform: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.DeletePlatform(context.Background(), platform.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{
|
||||||
|
PlatformID: platform.ID,
|
||||||
|
ModelName: "mock-wallet-image",
|
||||||
|
ProviderModelName: "mock-wallet-image",
|
||||||
|
ModelType: store.StringList{"image_generate"},
|
||||||
|
DisplayName: "Mock Wallet Image",
|
||||||
|
BillingConfig: map[string]any{
|
||||||
|
"image": map[string]any{"basePrice": 10},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create mock platform model: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &auth.User{
|
||||||
|
ID: gatewayUser.ID,
|
||||||
|
Source: "gateway",
|
||||||
|
GatewayUserID: gatewayUser.ID,
|
||||||
|
GatewayTenantID: tenant.ID,
|
||||||
|
TenantKey: tenant.TenantKey,
|
||||||
|
Roles: gatewayUser.Roles,
|
||||||
|
}
|
||||||
|
if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{
|
||||||
|
GatewayUserID: gatewayUser.ID,
|
||||||
|
Currency: "resource",
|
||||||
|
Balance: 10,
|
||||||
|
Reason: "seed wallet execute test",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("seed wallet balance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks := make([]store.GatewayTask, 0, 2)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
task, err := db.CreateTask(ctx, store.CreateTaskInput{
|
||||||
|
Kind: "images.generations",
|
||||||
|
Model: "mock-wallet-image",
|
||||||
|
Request: map[string]any{
|
||||||
|
"count": 1,
|
||||||
|
"prompt": "wallet execute test",
|
||||||
|
},
|
||||||
|
}, user)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create task: %v", err)
|
||||||
|
}
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockClient := &walletExecuteMockClient{}
|
||||||
|
service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||||
|
service.clients["mock"] = mockClient
|
||||||
|
|
||||||
|
type executeResult struct {
|
||||||
|
result Result
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
results := make(chan executeResult, len(tasks))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, task := range tasks {
|
||||||
|
task := task
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
result, err := service.Execute(ctx, task, user)
|
||||||
|
results <- executeResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
successCount := 0
|
||||||
|
insufficientCount := 0
|
||||||
|
for item := range results {
|
||||||
|
if item.err == nil {
|
||||||
|
successCount++
|
||||||
|
if item.result.Task.Status != "succeeded" {
|
||||||
|
t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status)
|
||||||
|
}
|
||||||
|
if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) {
|
||||||
|
t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(item.err, store.ErrInsufficientWalletBalance) {
|
||||||
|
insufficientCount++
|
||||||
|
if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" {
|
||||||
|
t.Fatalf("insufficient execution task = %+v", item.result.Task)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Fatalf("unexpected execute error: %v", item.err)
|
||||||
|
}
|
||||||
|
if successCount != 1 || insufficientCount != 1 {
|
||||||
|
t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
||||||
|
}
|
||||||
|
if got := mockClient.calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("mock client calls = %d, want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary, err := db.GetWalletSummary(ctx, user, "resource")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get wallet summary: %v", err)
|
||||||
|
}
|
||||||
|
account := summary.PrimaryAccount
|
||||||
|
if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) {
|
||||||
|
t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletExecuteFloatNear(a float64, b float64) bool {
|
||||||
|
delta := a - b
|
||||||
|
if delta < 0 {
|
||||||
|
delta = -delta
|
||||||
|
}
|
||||||
|
return delta < 0.000001
|
||||||
|
}
|
||||||
530
apps/api/internal/script/executor.go
Normal file
530
apps/api/internal/script/executor.go
Normal file
@ -0,0 +1,530 @@
|
|||||||
|
package script
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dop251/goja"
|
||||||
|
"github.com/dop251/goja_nodejs/eventloop"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultTimeout = 30 * time.Second
|
||||||
|
PreprocessTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Debug(msg string, args ...any)
|
||||||
|
Info(msg string, args ...any)
|
||||||
|
Warn(msg string, args ...any)
|
||||||
|
Error(msg string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Executor struct {
|
||||||
|
HTTPClient *http.Client
|
||||||
|
Logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Script string
|
||||||
|
Args []any
|
||||||
|
ContextData map[string]any
|
||||||
|
ScriptName string
|
||||||
|
PreferredEntryNames []string
|
||||||
|
Timeout time.Duration
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.Message) != "" {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
return e.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) ErrorCode() string {
|
||||||
|
if e == nil || strings.TrimSpace(e.Code) == "" {
|
||||||
|
return "script_error"
|
||||||
|
}
|
||||||
|
return e.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
value any
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
functionDeclarationPattern = regexp.MustCompile(`(?:^|\n)\s*(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(`)
|
||||||
|
assignedFunctionPattern = regexp.MustCompile(`(?:^|\n)\s*(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?(?:function\b|\([^)]*\)\s*=>|[A-Za-z_$][\w$]*\s*=>)`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e Executor) Execute(ctx context.Context, opts Options) (any, error) {
|
||||||
|
scriptText := strings.TrimSpace(opts.Script)
|
||||||
|
if scriptText == "" {
|
||||||
|
return nil, &Error{Code: "script_empty", Message: "script is empty"}
|
||||||
|
}
|
||||||
|
scriptName := strings.TrimSpace(opts.ScriptName)
|
||||||
|
if scriptName == "" {
|
||||||
|
scriptName = "script"
|
||||||
|
}
|
||||||
|
timeout := opts.Timeout
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = DefaultTimeout
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
loop := eventloop.NewEventLoop(eventloop.EnableConsole(false))
|
||||||
|
loop.Start()
|
||||||
|
defer loop.Terminate()
|
||||||
|
|
||||||
|
resultCh := make(chan result, 1)
|
||||||
|
var once sync.Once
|
||||||
|
finish := func(value any, err error) {
|
||||||
|
once.Do(func() {
|
||||||
|
resultCh <- result{value: value, err: err}
|
||||||
|
loop.StopNoWait()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := loop.RunOnLoop(func(vm *goja.Runtime) {
|
||||||
|
e.installRuntime(ctx, loop, vm, opts.HTTPClient, scriptName)
|
||||||
|
for key, value := range opts.ContextData {
|
||||||
|
_ = vm.Set(key, value)
|
||||||
|
}
|
||||||
|
value, err := e.invoke(vm, scriptText, opts.Args, opts.PreferredEntryNames, scriptName)
|
||||||
|
if err != nil {
|
||||||
|
finish(nil, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.resolveValue(vm, value, finish)
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return nil, &Error{Code: "script_runtime_error", Message: "script event loop is not available"}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case out := <-resultCh:
|
||||||
|
if out.err != nil {
|
||||||
|
return nil, out.err
|
||||||
|
}
|
||||||
|
return normalizeExport(out.value), nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
loop.Terminate()
|
||||||
|
code := "script_timeout"
|
||||||
|
if errors.Is(ctx.Err(), context.Canceled) {
|
||||||
|
code = "script_cancelled"
|
||||||
|
}
|
||||||
|
return nil, &Error{Code: code, Message: fmt.Sprintf("%s exceeded %s", scriptName, timeout)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Executor) invoke(vm *goja.Runtime, scriptText string, args []any, preferred []string, scriptName string) (goja.Value, error) {
|
||||||
|
if fnValue, err := vm.RunString("(" + scriptText + ")"); err == nil {
|
||||||
|
if fn, ok := goja.AssertFunction(fnValue); ok {
|
||||||
|
return fn(goja.Undefined(), values(vm, args)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := vm.RunString(scriptText); err != nil {
|
||||||
|
return nil, &Error{Code: "script_compile_error", Message: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range entryCandidates(scriptText, preferred) {
|
||||||
|
fnValue, err := vm.RunString(fmt.Sprintf("(typeof %s === 'function' ? %s : undefined)", name, name))
|
||||||
|
if err != nil || goja.IsUndefined(fnValue) || goja.IsNull(fnValue) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fn, ok := goja.AssertFunction(fnValue)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fn(goja.Undefined(), values(vm, args)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &Error{Code: "script_entry_missing", Message: fmt.Sprintf("%s must expose an executable function", scriptName)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Executor) resolveValue(vm *goja.Runtime, value goja.Value, finish func(any, error)) {
|
||||||
|
if value == nil {
|
||||||
|
finish(nil, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if promise, ok := value.Export().(*goja.Promise); ok {
|
||||||
|
switch promise.State() {
|
||||||
|
case goja.PromiseStateFulfilled:
|
||||||
|
finish(exportValue(promise.Result()), nil)
|
||||||
|
case goja.PromiseStateRejected:
|
||||||
|
finish(nil, &Error{Code: "script_error", Message: stringify(promise.Result())})
|
||||||
|
default:
|
||||||
|
obj := value.ToObject(vm)
|
||||||
|
thenFn, ok := goja.AssertFunction(obj.Get("then"))
|
||||||
|
if !ok {
|
||||||
|
finish(nil, &Error{Code: "script_error", Message: "promise.then is not callable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onResolve := func(call goja.FunctionCall) goja.Value {
|
||||||
|
finish(exportValue(call.Argument(0)), nil)
|
||||||
|
return goja.Undefined()
|
||||||
|
}
|
||||||
|
onReject := func(call goja.FunctionCall) goja.Value {
|
||||||
|
finish(nil, &Error{Code: "script_error", Message: stringify(call.Argument(0))})
|
||||||
|
return goja.Undefined()
|
||||||
|
}
|
||||||
|
_, _ = thenFn(obj, vm.ToValue(onResolve), vm.ToValue(onReject))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
finish(exportValue(value), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Executor) installRuntime(ctx context.Context, loop *eventloop.EventLoop, vm *goja.Runtime, client *http.Client, scriptName string) {
|
||||||
|
vm.SetFieldNameMapper(goja.TagFieldNameMapper("json", true))
|
||||||
|
e.installConsole(vm, scriptName)
|
||||||
|
e.installHTTP(ctx, loop, vm, firstHTTPClient(client, e.HTTPClient), scriptName)
|
||||||
|
_ = vm.Set("FormData", formDataConstructor(vm))
|
||||||
|
_, _ = vm.RunString(`
|
||||||
|
function __easyaiGotRequest(method, url, options) {
|
||||||
|
return {
|
||||||
|
json: function() { return __easyaiHTTP(method, url, options || {}).then(function(resp) { return resp.json(); }); },
|
||||||
|
text: function() { return __easyaiHTTP(method, url, options || {}).then(function(resp) { return resp.text(); }); }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
var got = {
|
||||||
|
get: function(url, options) { return __easyaiGotRequest("GET", url, options); },
|
||||||
|
post: function(url, options) { return __easyaiGotRequest("POST", url, options); },
|
||||||
|
put: function(url, options) { return __easyaiGotRequest("PUT", url, options); },
|
||||||
|
patch: function(url, options) { return __easyaiGotRequest("PATCH", url, options); },
|
||||||
|
delete: function(url, options) { return __easyaiGotRequest("DELETE", url, options); },
|
||||||
|
extend: function() { return this; }
|
||||||
|
};
|
||||||
|
function fetch(url, options) {
|
||||||
|
options = options || {};
|
||||||
|
return __easyaiHTTP(options.method || "GET", url, options);
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Executor) installConsole(vm *goja.Runtime, scriptName string) {
|
||||||
|
log := func(level string, args ...any) {
|
||||||
|
if e.Logger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
values := make([]any, 0, len(args)+1)
|
||||||
|
values = append(values, "script", scriptName)
|
||||||
|
values = append(values, args...)
|
||||||
|
switch level {
|
||||||
|
case "error":
|
||||||
|
e.Logger.Error("script console", values...)
|
||||||
|
case "warn":
|
||||||
|
e.Logger.Warn("script console", values...)
|
||||||
|
case "info":
|
||||||
|
e.Logger.Info("script console", values...)
|
||||||
|
default:
|
||||||
|
e.Logger.Debug("script console", values...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = vm.Set("console", map[string]any{
|
||||||
|
"log": func(args ...any) { log("debug", args...) },
|
||||||
|
"debug": func(args ...any) { log("debug", args...) },
|
||||||
|
"info": func(args ...any) { log("info", args...) },
|
||||||
|
"warn": func(args ...any) { log("warn", args...) },
|
||||||
|
"error": func(args ...any) { log("error", args...) },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Executor) installHTTP(ctx context.Context, loop *eventloop.EventLoop, vm *goja.Runtime, client *http.Client, scriptName string) {
|
||||||
|
_ = vm.Set("__easyaiHTTP", func(call goja.FunctionCall) goja.Value {
|
||||||
|
method := strings.ToUpper(strings.TrimSpace(call.Argument(0).String()))
|
||||||
|
if method == "" {
|
||||||
|
method = http.MethodGet
|
||||||
|
}
|
||||||
|
url := strings.TrimSpace(call.Argument(1).String())
|
||||||
|
options := exportMap(call.Argument(2))
|
||||||
|
promise, resolve, reject := vm.NewPromise()
|
||||||
|
go func() {
|
||||||
|
response, err := doHTTPRequest(ctx, client, method, url, options)
|
||||||
|
loop.RunOnLoop(func(runtime *goja.Runtime) {
|
||||||
|
if err != nil {
|
||||||
|
_ = reject(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = resolve(httpResponseObject(runtime, response))
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
return vm.ToValue(promise)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func doHTTPRequest(ctx context.Context, client *http.Client, method string, url string, options map[string]any) (httpScriptResponse, error) {
|
||||||
|
if strings.TrimSpace(url) == "" {
|
||||||
|
return httpScriptResponse{}, errors.New("url is required")
|
||||||
|
}
|
||||||
|
var body io.Reader
|
||||||
|
headers := map[string]string{}
|
||||||
|
if rawHeaders, ok := options["headers"].(map[string]any); ok {
|
||||||
|
for key, value := range rawHeaders {
|
||||||
|
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
|
||||||
|
headers[key] = text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if jsonBody, ok := options["json"]; ok {
|
||||||
|
raw, err := json.Marshal(jsonBody)
|
||||||
|
if err != nil {
|
||||||
|
return httpScriptResponse{}, err
|
||||||
|
}
|
||||||
|
body = bytes.NewReader(raw)
|
||||||
|
if _, ok := headers["Content-Type"]; !ok {
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
}
|
||||||
|
} else if rawBody, ok := options["body"]; ok {
|
||||||
|
body, headers = requestBody(rawBody, headers)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return httpScriptResponse{}, err
|
||||||
|
}
|
||||||
|
for key, value := range headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return httpScriptResponse{}, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
||||||
|
out := httpScriptResponse{
|
||||||
|
Status: resp.Status,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
OK: resp.StatusCode >= 200 && resp.StatusCode < 300,
|
||||||
|
Headers: map[string]any{},
|
||||||
|
Body: string(raw),
|
||||||
|
}
|
||||||
|
for key, values := range resp.Header {
|
||||||
|
if len(values) == 1 {
|
||||||
|
out.Headers[key] = values[0]
|
||||||
|
} else {
|
||||||
|
out.Headers[key] = values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(raw) > 0 {
|
||||||
|
var parsed any
|
||||||
|
if json.Unmarshal(raw, &parsed) == nil {
|
||||||
|
out.JSON = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpScriptResponse struct {
|
||||||
|
Status string
|
||||||
|
StatusCode int
|
||||||
|
OK bool
|
||||||
|
Headers map[string]any
|
||||||
|
Body string
|
||||||
|
JSON any
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpResponseObject(vm *goja.Runtime, response httpScriptResponse) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"status": response.StatusCode,
|
||||||
|
"statusCode": response.StatusCode,
|
||||||
|
"ok": response.OK,
|
||||||
|
"headers": response.Headers,
|
||||||
|
"text": func() string {
|
||||||
|
return response.Body
|
||||||
|
},
|
||||||
|
"json": func() any {
|
||||||
|
if response.JSON != nil {
|
||||||
|
return response.JSON
|
||||||
|
}
|
||||||
|
var parsed any
|
||||||
|
if json.Unmarshal([]byte(response.Body), &parsed) == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
panic(vm.NewTypeError("response body is not valid JSON"))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestBody(value any, headers map[string]string) (io.Reader, map[string]string) {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.NewReader(typed), headers
|
||||||
|
case []byte:
|
||||||
|
return bytes.NewReader(typed), headers
|
||||||
|
case map[string]any:
|
||||||
|
if fields, ok := typed["__easyaiFormData"].([]any); ok {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&buf)
|
||||||
|
for _, rawField := range fields {
|
||||||
|
field, ok := rawField.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = writer.WriteField(strings.TrimSpace(fmt.Sprint(field["name"])), fmt.Sprint(field["value"]))
|
||||||
|
}
|
||||||
|
_ = writer.Close()
|
||||||
|
headers["Content-Type"] = writer.FormDataContentType()
|
||||||
|
return &buf, headers
|
||||||
|
}
|
||||||
|
raw, _ := json.Marshal(typed)
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
return bytes.NewReader(raw), headers
|
||||||
|
default:
|
||||||
|
raw, _ := json.Marshal(typed)
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
return bytes.NewReader(raw), headers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formDataConstructor(vm *goja.Runtime) func(goja.ConstructorCall) *goja.Object {
|
||||||
|
return func(call goja.ConstructorCall) *goja.Object {
|
||||||
|
obj := call.This
|
||||||
|
_ = obj.Set("__easyaiFormData", []any{})
|
||||||
|
_ = obj.Set("append", func(name string, value any) {
|
||||||
|
fields := exportSlice(obj.Get("__easyaiFormData"))
|
||||||
|
fields = append(fields, map[string]any{"name": name, "value": value})
|
||||||
|
_ = obj.Set("__easyaiFormData", fields)
|
||||||
|
})
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func entryCandidates(scriptText string, preferred []string) []string {
|
||||||
|
values := make([]string, 0, len(preferred)+4)
|
||||||
|
appendUnique := func(value string) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, existing := range values {
|
||||||
|
if existing == value {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
for _, value := range preferred {
|
||||||
|
appendUnique(value)
|
||||||
|
}
|
||||||
|
for _, match := range functionDeclarationPattern.FindAllStringSubmatch(scriptText, -1) {
|
||||||
|
appendUnique(match[1])
|
||||||
|
}
|
||||||
|
for _, match := range assignedFunctionPattern.FindAllStringSubmatch(scriptText, -1) {
|
||||||
|
appendUnique(match[1])
|
||||||
|
}
|
||||||
|
appendUnique("main")
|
||||||
|
appendUnique("handler")
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func values(vm *goja.Runtime, input []any) []goja.Value {
|
||||||
|
out := make([]goja.Value, 0, len(input))
|
||||||
|
for _, item := range input {
|
||||||
|
out = append(out, toValue(vm, item))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func toValue(vm *goja.Runtime, item any) goja.Value {
|
||||||
|
if values, ok := item.(map[string]any); ok {
|
||||||
|
copied := map[string]any{}
|
||||||
|
for key, value := range values {
|
||||||
|
if key == "__easyaiScriptContext" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copied[key] = value
|
||||||
|
}
|
||||||
|
obj := vm.ToValue(copied).ToObject(vm)
|
||||||
|
if marker, _ := values["__easyaiScriptContext"].(bool); marker {
|
||||||
|
_ = obj.Set("got", vm.Get("got"))
|
||||||
|
_ = obj.Set("fetch", vm.Get("fetch"))
|
||||||
|
_ = obj.Set("FormData", vm.Get("FormData"))
|
||||||
|
}
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
return vm.ToValue(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportValue(value goja.Value) any {
|
||||||
|
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return value.Export()
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportMap(value goja.Value) map[string]any {
|
||||||
|
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
if typed, ok := normalizeExport(value.Export()).(map[string]any); ok {
|
||||||
|
return typed
|
||||||
|
}
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportSlice(value goja.Value) []any {
|
||||||
|
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||||
|
return []any{}
|
||||||
|
}
|
||||||
|
if typed, ok := normalizeExport(value.Export()).([]any); ok {
|
||||||
|
return typed
|
||||||
|
}
|
||||||
|
return []any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeExport(value any) any {
|
||||||
|
raw, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
var out any
|
||||||
|
if json.Unmarshal(raw, &out) != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstHTTPClient(values ...*http.Client) *http.Client {
|
||||||
|
for _, value := range values {
|
||||||
|
if value != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return http.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringify(value goja.Value) string {
|
||||||
|
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||||
|
return "script rejected"
|
||||||
|
}
|
||||||
|
if exported, ok := normalizeExport(value.Export()).(map[string]any); ok {
|
||||||
|
for _, key := range []string{"message", "error", "error_message"} {
|
||||||
|
if message := strings.TrimSpace(fmt.Sprint(exported[key])); message != "" && message != "<nil>" {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(value.String())
|
||||||
|
}
|
||||||
116
apps/api/internal/script/executor_test.go
Normal file
116
apps/api/internal/script/executor_test.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package script
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExecutorRunsFunctionExpression(t *testing.T) {
|
||||||
|
out, err := (Executor{}).Execute(context.Background(), Options{
|
||||||
|
Script: `(params) => ({ prompt: params.prompt.toUpperCase(), n: 2 })`,
|
||||||
|
Args: []any{map[string]any{"prompt": "hello"}},
|
||||||
|
ScriptName: "custom_preprocess_script",
|
||||||
|
Timeout: time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
result := out.(map[string]any)
|
||||||
|
if result["prompt"] != "HELLO" || result["n"].(float64) != 2 {
|
||||||
|
t.Fatalf("unexpected result: %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecutorSelectsPreferredEntry(t *testing.T) {
|
||||||
|
out, err := (Executor{}).Execute(context.Background(), Options{
|
||||||
|
Script: `
|
||||||
|
function helper() { return { wrong: true }; }
|
||||||
|
async function submitTask(payload, context) {
|
||||||
|
return { status: "submitted", task_id: payload.id, baseURL: context.baseURL };
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
Args: []any{map[string]any{"id": "task-1"}, map[string]any{"baseURL": "https://example.test"}},
|
||||||
|
ContextData: map[string]any{"baseURL": "https://example.test"},
|
||||||
|
PreferredEntryNames: []string{"submitTask", "submit"},
|
||||||
|
ScriptName: "custom_submit_script:video_generate",
|
||||||
|
Timeout: time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
result := out.(map[string]any)
|
||||||
|
if result["task_id"] != "task-1" || result["baseURL"] != "https://example.test" {
|
||||||
|
t.Fatalf("unexpected result: %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecutorGotJSONHelper(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Fatalf("unexpected method: %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test" {
|
||||||
|
t.Fatalf("missing authorization header")
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"status":"success","task_id":"remote-1"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
out, err := (Executor{}).Execute(context.Background(), Options{
|
||||||
|
Script: `
|
||||||
|
async function submitTask(payload, context) {
|
||||||
|
return await got.post(context.baseURL + "/tasks", {
|
||||||
|
headers: { Authorization: "Bearer " + context.authValues.apiKey },
|
||||||
|
json: payload
|
||||||
|
}).json();
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
Args: []any{map[string]any{"prompt": "hello"}, map[string]any{"baseURL": server.URL, "authValues": map[string]any{"apiKey": "test"}}},
|
||||||
|
ContextData: map[string]any{"baseURL": server.URL, "authValues": map[string]any{"apiKey": "test"}},
|
||||||
|
PreferredEntryNames: []string{"submitTask"},
|
||||||
|
ScriptName: "custom_submit_script:image_generate",
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute failed: %v", err)
|
||||||
|
}
|
||||||
|
result := out.(map[string]any)
|
||||||
|
if result["task_id"] != "remote-1" {
|
||||||
|
t.Fatalf("unexpected result: %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecutorTimeout(t *testing.T) {
|
||||||
|
_, err := (Executor{}).Execute(context.Background(), Options{
|
||||||
|
Script: `async function main() { await new Promise((resolve) => setTimeout(resolve, 200)); return true; }`,
|
||||||
|
ScriptName: "custom_poll_script",
|
||||||
|
Timeout: 25 * time.Millisecond,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout")
|
||||||
|
}
|
||||||
|
scriptErr, ok := err.(*Error)
|
||||||
|
if !ok || scriptErr.Code != "script_timeout" {
|
||||||
|
t.Fatalf("expected script_timeout, got %#v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecutorRejectedPromiseMessage(t *testing.T) {
|
||||||
|
_, err := (Executor{}).Execute(context.Background(), Options{
|
||||||
|
Script: `async function main() { throw new Error("boom"); }`,
|
||||||
|
ScriptName: "custom_submit_script",
|
||||||
|
Timeout: time.Second,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected rejection")
|
||||||
|
}
|
||||||
|
scriptErr, ok := err.(*Error)
|
||||||
|
if !ok || scriptErr.Code != "script_error" || !strings.Contains(scriptErr.Message, "boom") {
|
||||||
|
t.Fatalf("expected script_error with boom, got %#v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -14,6 +14,7 @@ var (
|
|||||||
type ModelCandidateUnavailableError struct {
|
type ModelCandidateUnavailableError struct {
|
||||||
Code string
|
Code string
|
||||||
Message string
|
Message string
|
||||||
|
Details map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ModelCandidateUnavailableError) Error() string {
|
func (e *ModelCandidateUnavailableError) Error() string {
|
||||||
@ -32,6 +33,14 @@ func ModelCandidateErrorCode(err error) string {
|
|||||||
return "no_model_candidate"
|
return "no_model_candidate"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ModelCandidateErrorDetails(err error) map[string]any {
|
||||||
|
var candidateErr *ModelCandidateUnavailableError
|
||||||
|
if errors.As(err, &candidateErr) && len(candidateErr.Details) > 0 {
|
||||||
|
return candidateErr.Details
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type RateLimitExceededError struct {
|
type RateLimitExceededError struct {
|
||||||
ScopeType string
|
ScopeType string
|
||||||
ScopeKey string
|
ScopeKey string
|
||||||
@ -247,6 +256,7 @@ type FinishTaskFailureInput struct {
|
|||||||
TaskID string
|
TaskID string
|
||||||
Code string
|
Code string
|
||||||
Message string
|
Message string
|
||||||
|
Result map[string]any
|
||||||
RequestID string
|
RequestID string
|
||||||
Metrics map[string]any
|
Metrics map[string]any
|
||||||
ResponseStartedAt time.Time
|
ResponseStartedAt time.Time
|
||||||
|
|||||||
@ -3,13 +3,13 @@ package store
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TaskListFilter struct {
|
type TaskListFilter struct {
|
||||||
@ -687,14 +687,15 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
|||||||
if currency == "" || currency == "mixed" {
|
if currency == "" || currency == "mixed" {
|
||||||
currency = "resource"
|
currency = "resource"
|
||||||
}
|
}
|
||||||
metadata, _ := json.Marshal(map[string]any{
|
metadataMap := map[string]any{
|
||||||
"taskId": task.ID,
|
"taskId": task.ID,
|
||||||
"kind": task.Kind,
|
"kind": task.Kind,
|
||||||
"model": task.Model,
|
"model": task.Model,
|
||||||
"resolvedModel": task.ResolvedModel,
|
"resolvedModel": task.ResolvedModel,
|
||||||
"billings": task.Billings,
|
"billings": task.Billings,
|
||||||
"billingSummary": task.BillingSummary,
|
"billingSummary": task.BillingSummary,
|
||||||
})
|
}
|
||||||
|
metadata, _ := json.Marshal(metadataMap)
|
||||||
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||||
if _, err := tx.Exec(ctx, `
|
if _, err := tx.Exec(ctx, `
|
||||||
INSERT INTO gateway_wallet_accounts (
|
INSERT INTO gateway_wallet_accounts (
|
||||||
@ -706,42 +707,85 @@ ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var exists bool
|
var exists bool
|
||||||
|
var accountID string
|
||||||
|
var balanceBefore float64
|
||||||
|
var frozenBefore float64
|
||||||
|
var gatewayTenantID string
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT id::text, balance::float8, frozen_balance::float8, COALESCE(gateway_tenant_id::text, '')
|
||||||
|
FROM gateway_wallet_accounts
|
||||||
|
WHERE gateway_user_id = $1::uuid
|
||||||
|
AND currency = $2
|
||||||
|
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &frozenBefore, &gatewayTenantID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := tx.QueryRow(ctx, `
|
if err := tx.QueryRow(ctx, `
|
||||||
SELECT EXISTS (
|
SELECT EXISTS (
|
||||||
SELECT 1
|
SELECT 1
|
||||||
FROM gateway_wallet_transactions t
|
FROM gateway_wallet_transactions
|
||||||
JOIN gateway_wallet_accounts a ON a.id = t.account_id
|
WHERE account_id = $1::uuid
|
||||||
WHERE a.gateway_user_id = $1::uuid
|
AND idempotency_key = $2
|
||||||
AND a.currency = $2
|
)`, accountID, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
||||||
AND t.idempotency_key = $3
|
|
||||||
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if exists {
|
if exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var accountID string
|
|
||||||
var balanceBefore float64
|
amount := roundMoney(task.FinalChargeAmount)
|
||||||
var gatewayTenantID string
|
reservationKey, reservedAmount, err := activeWalletReservation(ctx, tx, accountID, task.ID)
|
||||||
if err := tx.QueryRow(ctx, `
|
if err != nil {
|
||||||
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
|
|
||||||
FROM gateway_wallet_accounts
|
|
||||||
WHERE gateway_user_id = $1::uuid
|
|
||||||
AND currency = $2
|
|
||||||
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
amount := roundMoney(task.FinalChargeAmount)
|
reservedAmount = roundMoney(reservedAmount)
|
||||||
|
spendableForTask := roundMoney(balanceBefore - frozenBefore + reservedAmount)
|
||||||
|
if spendableForTask+0.000001 < amount {
|
||||||
|
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, currency, spendableForTask)
|
||||||
|
}
|
||||||
|
|
||||||
balanceAfter := roundMoney(balanceBefore - amount)
|
balanceAfter := roundMoney(balanceBefore - amount)
|
||||||
|
frozenAfter := roundMoney(frozenBefore - reservedAmount)
|
||||||
|
if frozenAfter < 0 {
|
||||||
|
frozenAfter = 0
|
||||||
|
}
|
||||||
if _, err := tx.Exec(ctx, `
|
if _, err := tx.Exec(ctx, `
|
||||||
UPDATE gateway_wallet_accounts
|
UPDATE gateway_wallet_accounts
|
||||||
SET balance = $2,
|
SET balance = $2,
|
||||||
total_spent = total_spent + $3,
|
total_spent = total_spent + $3,
|
||||||
|
frozen_balance = $4,
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
|
WHERE id = $1::uuid`, accountID, balanceAfter, amount, frozenAfter); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err := tx.Exec(ctx, `
|
if reservedAmount > 0 {
|
||||||
|
releaseMetadata, _ := json.Marshal(map[string]any{
|
||||||
|
"taskId": task.ID,
|
||||||
|
"reason": "task_billing_settled",
|
||||||
|
"reserved": reservedAmount,
|
||||||
|
"frozenBefore": roundMoney(frozenBefore),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_wallet_transactions (
|
||||||
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
|
||||||
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
|
)
|
||||||
|
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
|
||||||
|
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, reservedAmount, roundMoney(balanceBefore), roundMoney(balanceBefore), billingReservationReleaseIdempotencyKey(reservationKey), task.ID, string(releaseMetadata)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
billingMetadata := mergeObjects(metadataMap, map[string]any{
|
||||||
|
"reservedAmount": reservedAmount,
|
||||||
|
"frozenBefore": roundMoney(frozenBefore),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
metadata, _ = json.Marshal(billingMetadata)
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
INSERT INTO gateway_wallet_transactions (
|
INSERT INTO gateway_wallet_transactions (
|
||||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
@ -750,11 +794,10 @@ VALUES (
|
|||||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
||||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
)`,
|
)`,
|
||||||
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
|
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)); err != nil {
|
||||||
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -778,9 +821,10 @@ func taskBillingString(value any) string {
|
|||||||
|
|
||||||
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
func (s *Store) FinishTaskFailure(ctx context.Context, input FinishTaskFailureInput) (GatewayTask, error) {
|
||||||
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
metricsJSON, _ := json.Marshal(emptyObjectIfNil(input.Metrics))
|
||||||
|
resultJSON, _ := json.Marshal(emptyObjectIfNil(input.Result))
|
||||||
if _, err := s.pool.Exec(ctx, `
|
if _, err := s.pool.Exec(ctx, `
|
||||||
UPDATE gateway_tasks
|
UPDATE gateway_tasks
|
||||||
SET status = 'failed',
|
SET status = 'failed',
|
||||||
error = NULLIF($2::text, ''),
|
error = NULLIF($2::text, ''),
|
||||||
error_code = NULLIF($3::text, ''),
|
error_code = NULLIF($3::text, ''),
|
||||||
error_message = NULLIF($2::text, ''),
|
error_message = NULLIF($2::text, ''),
|
||||||
@ -789,6 +833,7 @@ SET status = 'failed',
|
|||||||
response_started_at = $6::timestamptz,
|
response_started_at = $6::timestamptz,
|
||||||
response_finished_at = $7::timestamptz,
|
response_finished_at = $7::timestamptz,
|
||||||
response_duration_ms = $8,
|
response_duration_ms = $8,
|
||||||
|
result = $9::jsonb,
|
||||||
locked_by = NULL,
|
locked_by = NULL,
|
||||||
locked_at = NULL,
|
locked_at = NULL,
|
||||||
heartbeat_at = NULL,
|
heartbeat_at = NULL,
|
||||||
@ -803,6 +848,7 @@ WHERE id = $1::uuid`,
|
|||||||
nullableTime(input.ResponseStartedAt),
|
nullableTime(input.ResponseStartedAt),
|
||||||
nullableTime(input.ResponseFinishedAt),
|
nullableTime(input.ResponseFinishedAt),
|
||||||
input.ResponseDurationMS,
|
input.ResponseDurationMS,
|
||||||
|
string(resultJSON),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return GatewayTask{}, err
|
return GatewayTask{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -92,6 +93,16 @@ type WalletAdjustmentResult struct {
|
|||||||
Transaction GatewayWalletTransaction `json:"transaction"`
|
Transaction GatewayWalletTransaction `json:"transaction"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WalletBillingReservation struct {
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
AccountID string `json:"accountId"`
|
||||||
|
GatewayUserID string `json:"gatewayUserId"`
|
||||||
|
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
Amount float64 `json:"amount"`
|
||||||
|
IdempotencyKey string `json:"idempotencyKey"`
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
|
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
|
||||||
gatewayUserID := localGatewayUserID(user)
|
gatewayUserID := localGatewayUserID(user)
|
||||||
if gatewayUserID == "" {
|
if gatewayUserID == "" {
|
||||||
@ -115,6 +126,223 @@ func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currenc
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) ReserveTaskBilling(ctx context.Context, task GatewayTask, user *auth.User, billings []any) ([]WalletBillingReservation, error) {
|
||||||
|
gatewayUserID := taskGatewayUserID(task, user)
|
||||||
|
if gatewayUserID == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
taskID := strings.TrimSpace(task.ID)
|
||||||
|
if taskID == "" {
|
||||||
|
return nil, fmt.Errorf("task id is required for wallet reservation")
|
||||||
|
}
|
||||||
|
|
||||||
|
amounts := walletBillingAmounts(billings)
|
||||||
|
if len(amounts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reservations := make([]WalletBillingReservation, 0, len(amounts))
|
||||||
|
err := pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||||
|
for currency, rawAmount := range amounts {
|
||||||
|
amount := roundMoney(rawAmount)
|
||||||
|
if amount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.ensureWalletAccount(ctx, tx, gatewayUserID, currency)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
locked, err := lockWalletAccount(ctx, tx, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
activeKey, activeAmount, err := activeWalletReservation(ctx, tx, locked.ID, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if activeAmount > 0 {
|
||||||
|
reservation := WalletBillingReservation{
|
||||||
|
TaskID: taskID,
|
||||||
|
AccountID: locked.ID,
|
||||||
|
GatewayUserID: gatewayUserID,
|
||||||
|
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||||
|
Currency: locked.Currency,
|
||||||
|
Amount: activeAmount,
|
||||||
|
IdempotencyKey: activeKey,
|
||||||
|
}
|
||||||
|
reservations = append(reservations, reservation)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sequence, err := nextWalletReservationSequence(ctx, tx, locked.ID, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
key := billingReservationIdempotencyKey(taskID, locked.Currency, sequence)
|
||||||
|
reservation := WalletBillingReservation{
|
||||||
|
TaskID: taskID,
|
||||||
|
AccountID: locked.ID,
|
||||||
|
GatewayUserID: gatewayUserID,
|
||||||
|
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||||
|
Currency: locked.Currency,
|
||||||
|
Amount: amount,
|
||||||
|
IdempotencyKey: key,
|
||||||
|
}
|
||||||
|
available := roundMoney(locked.Balance - locked.FrozenBalance)
|
||||||
|
if available+0.000001 < amount {
|
||||||
|
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, locked.Currency, available)
|
||||||
|
}
|
||||||
|
|
||||||
|
frozenAfter := roundMoney(locked.FrozenBalance + amount)
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
UPDATE gateway_wallet_accounts
|
||||||
|
SET frozen_balance = $2,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
metadata, _ := json.Marshal(map[string]any{
|
||||||
|
"taskId": taskID,
|
||||||
|
"kind": task.Kind,
|
||||||
|
"model": task.Model,
|
||||||
|
"reserved": amount,
|
||||||
|
"balance": roundMoney(locked.Balance),
|
||||||
|
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_wallet_transactions (
|
||||||
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'reserve',
|
||||||
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
|
)`,
|
||||||
|
locked.ID,
|
||||||
|
firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||||
|
gatewayUserID,
|
||||||
|
amount,
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
key,
|
||||||
|
taskID,
|
||||||
|
string(metadata),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reservations = append(reservations, reservation)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return reservations, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ReleaseTaskBillingReservations(ctx context.Context, reservations []WalletBillingReservation, reason string) error {
|
||||||
|
if len(reservations) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if reason == "" {
|
||||||
|
reason = "task_not_settled"
|
||||||
|
}
|
||||||
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||||
|
for _, reservation := range reservations {
|
||||||
|
if reservation.Amount <= 0 || strings.TrimSpace(reservation.AccountID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reserveKey := strings.TrimSpace(reservation.IdempotencyKey)
|
||||||
|
if reserveKey == "" {
|
||||||
|
reserveKey = billingReservationIdempotencyKey(reservation.TaskID, reservation.Currency, 1)
|
||||||
|
}
|
||||||
|
releaseKey := billingReservationReleaseIdempotencyKey(reserveKey)
|
||||||
|
locked, err := lockWalletAccount(ctx, tx, reservation.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var alreadyReleased bool
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM gateway_wallet_transactions
|
||||||
|
WHERE account_id = $1::uuid
|
||||||
|
AND idempotency_key = $2
|
||||||
|
)`, reservation.AccountID, releaseKey).Scan(&alreadyReleased); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if alreadyReleased {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var storedReservedAmount float64
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT COALESCE((
|
||||||
|
SELECT amount::float8
|
||||||
|
FROM gateway_wallet_transactions
|
||||||
|
WHERE account_id = $1::uuid
|
||||||
|
AND idempotency_key = $2
|
||||||
|
AND transaction_type = 'reserve'
|
||||||
|
LIMIT 1
|
||||||
|
), 0)::float8`, reservation.AccountID, reserveKey).Scan(&storedReservedAmount); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if storedReservedAmount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
amount := roundMoney(storedReservedAmount)
|
||||||
|
frozenAfter := roundMoney(locked.FrozenBalance - amount)
|
||||||
|
if frozenAfter < 0 {
|
||||||
|
frozenAfter = 0
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
UPDATE gateway_wallet_accounts
|
||||||
|
SET frozen_balance = $2,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
metadata, _ := json.Marshal(map[string]any{
|
||||||
|
"taskId": reservation.TaskID,
|
||||||
|
"reason": reason,
|
||||||
|
"reserved": amount,
|
||||||
|
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_wallet_transactions (
|
||||||
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
|
||||||
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
|
)
|
||||||
|
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
|
||||||
|
locked.ID,
|
||||||
|
locked.GatewayTenantID,
|
||||||
|
locked.GatewayUserID,
|
||||||
|
amount,
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
releaseKey,
|
||||||
|
reservation.TaskID,
|
||||||
|
string(metadata),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
|
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
|
||||||
gatewayUserID := localGatewayUserID(user)
|
gatewayUserID := localGatewayUserID(user)
|
||||||
if gatewayUserID == "" {
|
if gatewayUserID == "" {
|
||||||
@ -465,6 +693,124 @@ WHERE gateway_user_id = $1::uuid
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func lockWalletAccount(ctx context.Context, tx pgx.Tx, accountID string) (GatewayWalletAccount, error) {
|
||||||
|
return scanWalletAccount(tx.QueryRow(ctx, `
|
||||||
|
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
|
||||||
|
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
|
||||||
|
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
|
||||||
|
total_spent::float8, status, metadata, created_at, updated_at
|
||||||
|
FROM gateway_wallet_accounts
|
||||||
|
WHERE id = $1::uuid
|
||||||
|
FOR UPDATE`, accountID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func activeWalletReservation(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (string, float64, error) {
|
||||||
|
var key string
|
||||||
|
var amount float64
|
||||||
|
err := tx.QueryRow(ctx, `
|
||||||
|
SELECT COALESCE(t.idempotency_key, ''), t.amount::float8
|
||||||
|
FROM gateway_wallet_transactions t
|
||||||
|
WHERE t.account_id = $1::uuid
|
||||||
|
AND t.reference_type = 'gateway_task'
|
||||||
|
AND t.reference_id = $2
|
||||||
|
AND t.transaction_type = 'reserve'
|
||||||
|
AND COALESCE(t.idempotency_key, '') <> ''
|
||||||
|
AND NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM gateway_wallet_transactions r
|
||||||
|
WHERE r.account_id = t.account_id
|
||||||
|
AND r.transaction_type = 'release'
|
||||||
|
AND r.idempotency_key = t.idempotency_key || ':release'
|
||||||
|
)
|
||||||
|
ORDER BY t.created_at DESC
|
||||||
|
LIMIT 1`, accountID, taskID).Scan(&key, &amount)
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
return "", 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
return key, roundMoney(amount), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextWalletReservationSequence(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (int, error) {
|
||||||
|
var count int
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT COUNT(*)::int
|
||||||
|
FROM gateway_wallet_transactions
|
||||||
|
WHERE account_id = $1::uuid
|
||||||
|
AND reference_type = 'gateway_task'
|
||||||
|
AND reference_id = $2
|
||||||
|
AND transaction_type = 'reserve'`, accountID, taskID).Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletBillingAmounts(billings []any) map[string]float64 {
|
||||||
|
amounts := map[string]float64{}
|
||||||
|
for _, raw := range billings {
|
||||||
|
line, _ := raw.(map[string]any)
|
||||||
|
if line == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
amount := roundMoney(walletFloat(line["amount"]))
|
||||||
|
if amount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currency := normalizeWalletCurrency(walletString(line["currency"]))
|
||||||
|
amounts[currency] = roundMoney(amounts[currency] + amount)
|
||||||
|
}
|
||||||
|
return amounts
|
||||||
|
}
|
||||||
|
|
||||||
|
func taskGatewayUserID(task GatewayTask, user *auth.User) string {
|
||||||
|
return firstNonEmpty(strings.TrimSpace(task.GatewayUserID), localGatewayUserID(user))
|
||||||
|
}
|
||||||
|
|
||||||
|
func billingReservationIdempotencyKey(taskID string, currency string, sequence int) string {
|
||||||
|
if sequence <= 0 {
|
||||||
|
sequence = 1
|
||||||
|
}
|
||||||
|
return "task:" + strings.TrimSpace(taskID) + ":wallet-reservation:" + normalizeWalletCurrency(currency) + ":" + strconv.Itoa(sequence)
|
||||||
|
}
|
||||||
|
|
||||||
|
func billingReservationReleaseIdempotencyKey(reservationKey string) string {
|
||||||
|
return strings.TrimSpace(reservationKey) + ":release"
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletString(value any) string {
|
||||||
|
if text, ok := value.(string); ok {
|
||||||
|
return strings.TrimSpace(text)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletFloat(value any) float64 {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case float64:
|
||||||
|
return typed
|
||||||
|
case float32:
|
||||||
|
return float64(typed)
|
||||||
|
case int:
|
||||||
|
return float64(typed)
|
||||||
|
case int64:
|
||||||
|
return float64(typed)
|
||||||
|
case json.Number:
|
||||||
|
next, _ := typed.Float64()
|
||||||
|
return next
|
||||||
|
case string:
|
||||||
|
next := strings.TrimSpace(typed)
|
||||||
|
if next == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
parsed, _ := strconv.ParseFloat(next, 64)
|
||||||
|
return parsed
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeWalletCurrency(currency string) string {
|
func normalizeWalletCurrency(currency string) string {
|
||||||
currency = strings.TrimSpace(currency)
|
currency = strings.TrimSpace(currency)
|
||||||
if currency == "" {
|
if currency == "" {
|
||||||
|
|||||||
171
apps/api/internal/store/wallet_reservation_test.go
Normal file
171
apps/api/internal/store/wallet_reservation_test.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) {
|
||||||
|
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
||||||
|
if databaseURL == "" {
|
||||||
|
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db, err := Connect(ctx, databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect store: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tenantID, userID := seedWalletReservationUser(t, ctx, db)
|
||||||
|
if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{
|
||||||
|
GatewayUserID: userID,
|
||||||
|
Currency: "resource",
|
||||||
|
Balance: 10,
|
||||||
|
Reason: "seed wallet reservation test",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("seed wallet balance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstTaskID := newWalletReservationTestUUID(t, ctx, db)
|
||||||
|
secondTaskID := newWalletReservationTestUUID(t, ctx, db)
|
||||||
|
billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}}
|
||||||
|
user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID}
|
||||||
|
tasks := []GatewayTask{
|
||||||
|
{ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"},
|
||||||
|
{ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"},
|
||||||
|
}
|
||||||
|
|
||||||
|
type reserveResult struct {
|
||||||
|
reservations []WalletBillingReservation
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
results := make(chan reserveResult, len(tasks))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, task := range tasks {
|
||||||
|
task := task
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
reservations, err := db.ReserveTaskBilling(ctx, task, user, billings)
|
||||||
|
results <- reserveResult{reservations: reservations, err: err}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
var successReservations []WalletBillingReservation
|
||||||
|
successCount := 0
|
||||||
|
insufficientCount := 0
|
||||||
|
for result := range results {
|
||||||
|
if result.err == nil {
|
||||||
|
successCount++
|
||||||
|
successReservations = result.reservations
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(result.err, ErrInsufficientWalletBalance) {
|
||||||
|
insufficientCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Fatalf("unexpected reservation error: %v", result.err)
|
||||||
|
}
|
||||||
|
if successCount != 1 || insufficientCount != 1 {
|
||||||
|
t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
||||||
|
}
|
||||||
|
if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) {
|
||||||
|
t.Fatalf("unexpected successful reservations: %+v", successReservations)
|
||||||
|
}
|
||||||
|
|
||||||
|
balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID)
|
||||||
|
if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) {
|
||||||
|
t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
||||||
|
}
|
||||||
|
|
||||||
|
settleTask := GatewayTask{
|
||||||
|
ID: successReservations[0].TaskID,
|
||||||
|
GatewayUserID: userID,
|
||||||
|
GatewayTenantID: tenantID,
|
||||||
|
Kind: "images.generations",
|
||||||
|
Model: "mock-image",
|
||||||
|
ResolvedModel: "mock-image",
|
||||||
|
Billings: billings,
|
||||||
|
BillingSummary: map[string]any{"currency": "resource", "totalAmount": float64(10)},
|
||||||
|
FinalChargeAmount: 10,
|
||||||
|
}
|
||||||
|
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
||||||
|
t.Fatalf("settle reserved task billing: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
||||||
|
t.Fatalf("settle reserved task billing should be idempotent: %v", err)
|
||||||
|
}
|
||||||
|
balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID)
|
||||||
|
if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) {
|
||||||
|
t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
var tenantID string
|
||||||
|
if err := db.pool.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_tenants (tenant_key, name)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING id::text`, "wallet-reservation-"+suffix, "Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil {
|
||||||
|
t.Fatalf("insert test tenant: %v", err)
|
||||||
|
}
|
||||||
|
var userID string
|
||||||
|
if err := db.pool.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles)
|
||||||
|
VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb)
|
||||||
|
RETURNING id::text`, "wallet-reservation-user-"+suffix, "wallet_reservation_"+suffix, tenantID, "wallet-reservation-"+suffix).Scan(&userID); err != nil {
|
||||||
|
t.Fatalf("insert test user: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cleanupCtx := context.Background()
|
||||||
|
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_users WHERE id = $1::uuid`, userID)
|
||||||
|
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID)
|
||||||
|
})
|
||||||
|
return tenantID, userID
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string {
|
||||||
|
t.Helper()
|
||||||
|
var id string
|
||||||
|
if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil {
|
||||||
|
t.Fatalf("generate uuid: %v", err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) {
|
||||||
|
t.Helper()
|
||||||
|
var balance float64
|
||||||
|
var frozen float64
|
||||||
|
var spent float64
|
||||||
|
if err := db.pool.QueryRow(ctx, `
|
||||||
|
SELECT balance::float8, frozen_balance::float8, total_spent::float8
|
||||||
|
FROM gateway_wallet_accounts
|
||||||
|
WHERE gateway_user_id = $1::uuid
|
||||||
|
AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil {
|
||||||
|
t.Fatalf("read wallet account: %v", err)
|
||||||
|
}
|
||||||
|
return balance, frozen, spent
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletFloatNear(a float64, b float64) bool {
|
||||||
|
delta := a - b
|
||||||
|
if delta < 0 {
|
||||||
|
delta = -delta
|
||||||
|
}
|
||||||
|
return delta < 0.000001
|
||||||
|
}
|
||||||
25
apps/api/migrations/0038_keling_omni_audio_flags.sql
Normal file
25
apps/api/migrations/0038_keling_omni_audio_flags.sql
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
UPDATE base_model_catalog
|
||||||
|
SET capabilities = jsonb_set(
|
||||||
|
jsonb_set(capabilities, '{omni_video,input_audio}', 'false'::jsonb, true),
|
||||||
|
'{omni_video,max_audios}', '0'::jsonb, true
|
||||||
|
),
|
||||||
|
metadata = jsonb_set(
|
||||||
|
jsonb_set(metadata, '{rawModel,capabilities,omni_video,input_audio}', 'false'::jsonb, true),
|
||||||
|
'{rawModel,capabilities,omni_video,max_audios}', '0'::jsonb, true
|
||||||
|
),
|
||||||
|
updated_at = now()
|
||||||
|
WHERE provider_key = 'keling'
|
||||||
|
AND provider_model_name IN ('kling-video-o1', 'kling-v3-omni')
|
||||||
|
AND capabilities ? 'omni_video';
|
||||||
|
|
||||||
|
UPDATE platform_models m
|
||||||
|
SET capabilities = jsonb_set(
|
||||||
|
jsonb_set(m.capabilities, '{omni_video,input_audio}', 'false'::jsonb, true),
|
||||||
|
'{omni_video,max_audios}', '0'::jsonb, true
|
||||||
|
),
|
||||||
|
updated_at = now()
|
||||||
|
FROM integration_platforms p
|
||||||
|
WHERE m.platform_id = p.id
|
||||||
|
AND p.provider = 'keling'
|
||||||
|
AND COALESCE(NULLIF(m.provider_model_name, ''), m.model_name) IN ('kling-video-o1', 'kling-v3-omni')
|
||||||
|
AND m.capabilities ? 'omni_video';
|
||||||
19
apps/api/migrations/0039_exclude_easyai_media_catalog.sql
Normal file
19
apps/api/migrations/0039_exclude_easyai_media_catalog.sql
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
-- EasyAI/server-main is intentionally not migrated as an AI Gateway runtime
|
||||||
|
-- provider. Keep its historical catalog rows for traceability, but hide them
|
||||||
|
-- from fresh admin selection and mark the exclusion reason explicitly.
|
||||||
|
UPDATE base_model_catalog
|
||||||
|
SET status = 'deprecated',
|
||||||
|
metadata = COALESCE(metadata, '{}'::jsonb) || jsonb_build_object(
|
||||||
|
'selectable', false,
|
||||||
|
'migrationExcludedReason', 'excluded from AI Gateway media runtime migration to avoid gateway-to-server-main loopback',
|
||||||
|
'migrationExcludedAt', '0039_exclude_easyai_media_catalog'
|
||||||
|
)
|
||||||
|
WHERE provider_key = 'easyai'
|
||||||
|
AND model_type ?| ARRAY[
|
||||||
|
'image_generate',
|
||||||
|
'image_edit',
|
||||||
|
'video_generate',
|
||||||
|
'image_to_video',
|
||||||
|
'omni_video',
|
||||||
|
'video_edit'
|
||||||
|
];
|
||||||
@ -684,11 +684,13 @@ export interface VideoGenerationContent {
|
|||||||
};
|
};
|
||||||
video_url?: {
|
video_url?: {
|
||||||
url: string;
|
url: string;
|
||||||
|
mime_type?: string;
|
||||||
refer_type?: 'feature' | 'base';
|
refer_type?: 'feature' | 'base';
|
||||||
keep_original_sound?: 'yes' | 'no';
|
keep_original_sound?: 'yes' | 'no';
|
||||||
};
|
};
|
||||||
audio_url?: {
|
audio_url?: {
|
||||||
url: string;
|
url: string;
|
||||||
|
mime_type?: string;
|
||||||
};
|
};
|
||||||
role?: VideoGenerationContentRole;
|
role?: VideoGenerationContentRole;
|
||||||
shot_index?: number;
|
shot_index?: number;
|
||||||
|
|||||||
@ -812,7 +812,7 @@ function TaskRecord(props: { task: GatewayTask; token: string; onCopyRequestId:
|
|||||||
<TableCell>{props.task.apiKeyName || props.task.apiKeyPrefix || props.task.apiKeyId || '-'}</TableCell>
|
<TableCell>{props.task.apiKeyName || props.task.apiKeyPrefix || props.task.apiKeyId || '-'}</TableCell>
|
||||||
<TableCell className="taskRecordTokenCell">{tokenUsage}</TableCell>
|
<TableCell className="taskRecordTokenCell">{tokenUsage}</TableCell>
|
||||||
<TableCell>{chargeText}</TableCell>
|
<TableCell>{chargeText}</TableCell>
|
||||||
<TableCell>{formatDuration(props.task.responseDurationMs)}</TableCell>
|
<TableCell>{formatDuration(taskDurationMs(props.task))}</TableCell>
|
||||||
<TableCell>{formatDateTime(props.task.createdAt)}</TableCell>
|
<TableCell>{formatDateTime(props.task.createdAt)}</TableCell>
|
||||||
<TableCell>
|
<TableCell>
|
||||||
<Button type="button" variant="ghost" size="sm" className="taskRecordJsonButton" title={taskErrorText(props.task) || '查看原始 JSON'} onClick={() => props.onOpenJson(props.task)}>
|
<Button type="button" variant="ghost" size="sm" className="taskRecordJsonButton" title={taskErrorText(props.task) || '查看原始 JSON'} onClick={() => props.onOpenJson(props.task)}>
|
||||||
@ -971,7 +971,10 @@ function TaskAttemptPopoverContent(props: { task: GatewayTask }) {
|
|||||||
const attempts = props.task.attempts ?? [];
|
const attempts = props.task.attempts ?? [];
|
||||||
return (
|
return (
|
||||||
<span className="taskRecordAttemptPopover" role="tooltip">
|
<span className="taskRecordAttemptPopover" role="tooltip">
|
||||||
{attempts.map((attempt) => (
|
{attempts.map((attempt) => {
|
||||||
|
const trace = taskAttemptTrace(attempt);
|
||||||
|
const rateLimitText = taskAttemptRateLimitText(attempt);
|
||||||
|
return (
|
||||||
<span
|
<span
|
||||||
key={attempt.id || `${props.task.id}-${attempt.attemptNo}`}
|
key={attempt.id || `${props.task.id}-${attempt.attemptNo}`}
|
||||||
className={`taskRecordAttemptDetail ${attempt.status === 'failed' ? 'failed' : attempt.status === 'succeeded' ? 'succeeded' : ''}`}
|
className={`taskRecordAttemptDetail ${attempt.status === 'failed' ? 'failed' : attempt.status === 'succeeded' ? 'succeeded' : ''}`}
|
||||||
@ -982,9 +985,10 @@ function TaskAttemptPopoverContent(props: { task: GatewayTask }) {
|
|||||||
</span>
|
</span>
|
||||||
<small>{taskAttemptMeta(attempt)}</small>
|
<small>{taskAttemptMeta(attempt)}</small>
|
||||||
{attempt.status === 'failed' && <span className="taskRecordAttemptError">{taskAttemptFailureReason(attempt)}</span>}
|
{attempt.status === 'failed' && <span className="taskRecordAttemptError">{taskAttemptFailureReason(attempt)}</span>}
|
||||||
{taskAttemptTrace(attempt).length > 0 && (
|
{(rateLimitText || trace.length > 0) && (
|
||||||
<span className="taskRecordAttemptTrace">
|
<span className="taskRecordAttemptTrace">
|
||||||
{taskAttemptTrace(attempt).map((entry, index) => (
|
{rateLimitText && <span className="taskRecordAttemptTraceItem">{rateLimitText}</span>}
|
||||||
|
{trace.map((entry, index) => (
|
||||||
<span key={`${attempt.id || attempt.attemptNo}-trace-${index}`} className="taskRecordAttemptTraceItem">
|
<span key={`${attempt.id || attempt.attemptNo}-trace-${index}`} className="taskRecordAttemptTraceItem">
|
||||||
{taskAttemptTraceText(entry)}
|
{taskAttemptTraceText(entry)}
|
||||||
</span>
|
</span>
|
||||||
@ -992,7 +996,8 @@ function TaskAttemptPopoverContent(props: { task: GatewayTask }) {
|
|||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
</span>
|
</span>
|
||||||
))}
|
);
|
||||||
|
})}
|
||||||
</span>
|
</span>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -1024,7 +1029,7 @@ function taskAttemptMeta(attempt: NonNullable<GatewayTask['attempts']>[number])
|
|||||||
attempt.providerModelName || attempt.modelName || attempt.modelAlias,
|
attempt.providerModelName || attempt.modelName || attempt.modelAlias,
|
||||||
attempt.requestId ? `RequestID ${attempt.requestId}` : '',
|
attempt.requestId ? `RequestID ${attempt.requestId}` : '',
|
||||||
statusCode ? `状态码 ${statusCode}` : '',
|
statusCode ? `状态码 ${statusCode}` : '',
|
||||||
attempt.responseDurationMs ? formatDuration(attempt.responseDurationMs) : '',
|
formatDuration(attemptDurationMs(attempt)),
|
||||||
].filter(Boolean);
|
].filter(Boolean);
|
||||||
return values.join(' · ') || attempt.clientId || '-';
|
return values.join(' · ') || attempt.clientId || '-';
|
||||||
}
|
}
|
||||||
@ -1055,6 +1060,29 @@ function taskAttemptTrace(attempt: NonNullable<GatewayTask['attempts']>[number])
|
|||||||
return raw.filter((item): item is Record<string, unknown> => Boolean(item) && typeof item === 'object' && !Array.isArray(item));
|
return raw.filter((item): item is Record<string, unknown> => Boolean(item) && typeof item === 'object' && !Array.isArray(item));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function taskAttemptRateLimitText(attempt: NonNullable<GatewayTask['attempts']>[number]) {
|
||||||
|
const detail = metadataObject(attempt.metrics, 'rateLimit');
|
||||||
|
if (!Object.keys(detail).length) return '';
|
||||||
|
const scopeName = objectString(detail, 'scopeName') || objectString(detail, 'scopeKey') || '限流对象';
|
||||||
|
const metric = objectString(detail, 'metric') || 'rate_limit';
|
||||||
|
const current = metadataNumber(detail, 'current');
|
||||||
|
const amount = metadataNumber(detail, 'amount');
|
||||||
|
const projected = metadataNumber(detail, 'projected');
|
||||||
|
const limit = metadataNumber(detail, 'limit');
|
||||||
|
const windowSeconds = metadataNumber(detail, 'windowSeconds');
|
||||||
|
const retryAfterMs = metadataNumber(detail, 'retryAfterMs');
|
||||||
|
const values = [
|
||||||
|
`限流 ${scopeName} · ${metric}`,
|
||||||
|
current !== null ? `当前 ${formatCellValue(current)}` : '',
|
||||||
|
amount !== null ? `本次 ${formatCellValue(amount)}` : '',
|
||||||
|
projected !== null ? `预计 ${formatCellValue(projected)}` : '',
|
||||||
|
limit !== null ? `限制 ${formatCellValue(limit)}` : '',
|
||||||
|
windowSeconds !== null ? `窗口 ${Math.trunc(windowSeconds)} 秒` : '',
|
||||||
|
retryAfterMs !== null ? `约 ${formatDuration(Math.trunc(retryAfterMs))} 后可重试` : '',
|
||||||
|
].filter(Boolean);
|
||||||
|
return values.join(' · ');
|
||||||
|
}
|
||||||
|
|
||||||
function taskAttemptTraceText(entry: Record<string, unknown>) {
|
function taskAttemptTraceText(entry: Record<string, unknown>) {
|
||||||
const event = objectString(entry, 'event');
|
const event = objectString(entry, 'event');
|
||||||
const action = objectString(entry, 'action');
|
const action = objectString(entry, 'action');
|
||||||
@ -1116,6 +1144,12 @@ function taskAttemptTraceReasonLabel(reason: string) {
|
|||||||
client_retryable: '客户端标记可重试',
|
client_retryable: '客户端标记可重试',
|
||||||
client_non_retryable: '客户端标记不可重试',
|
client_non_retryable: '客户端标记不可重试',
|
||||||
same_client_max_attempts: '达到本平台最大尝试次数',
|
same_client_max_attempts: '达到本平台最大尝试次数',
|
||||||
|
request_validation_failed: '请求校验失败',
|
||||||
|
candidate_selection_failed: '候选模型选择失败',
|
||||||
|
parameter_preprocessing_failed: '参数预处理失败',
|
||||||
|
wallet_balance_check_failed: '余额校验失败',
|
||||||
|
local_rate_limit_blocked: '本地限流拦截',
|
||||||
|
pre_provider_failed: '调用上游前失败',
|
||||||
local_rate_limit_wait_queue: '本地限流排队等待',
|
local_rate_limit_wait_queue: '本地限流排队等待',
|
||||||
failover_time_budget_exceeded: '超过全局切换时间预算',
|
failover_time_budget_exceeded: '超过全局切换时间预算',
|
||||||
runner_policy_disabled: '全局调度策略停用',
|
runner_policy_disabled: '全局调度策略停用',
|
||||||
@ -1321,10 +1355,41 @@ function tokenValue(value: unknown) {
|
|||||||
return Number.isFinite(numericValue) ? numericValue : null;
|
return Number.isFinite(numericValue) ? numericValue : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function taskDurationMs(task: GatewayTask) {
|
||||||
|
return (
|
||||||
|
positiveDurationMs(task.responseDurationMs) ??
|
||||||
|
elapsedDurationMs(task.responseStartedAt, task.responseFinishedAt) ??
|
||||||
|
elapsedDurationMs(task.createdAt, task.finishedAt)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function attemptDurationMs(attempt: NonNullable<GatewayTask['attempts']>[number]) {
|
||||||
|
return (
|
||||||
|
positiveDurationMs(attempt.responseDurationMs) ??
|
||||||
|
elapsedDurationMs(attempt.responseStartedAt, attempt.responseFinishedAt) ??
|
||||||
|
elapsedDurationMs(attempt.startedAt, attempt.finishedAt)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function positiveDurationMs(value?: number) {
|
||||||
|
if (value === undefined || value === null) return undefined;
|
||||||
|
const numericValue = Number(value);
|
||||||
|
return Number.isFinite(numericValue) && numericValue > 0 ? numericValue : undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
function elapsedDurationMs(start?: string, end?: string) {
|
||||||
|
if (!start || !end) return undefined;
|
||||||
|
const startedAt = new Date(start).getTime();
|
||||||
|
const finishedAt = new Date(end).getTime();
|
||||||
|
if (!Number.isFinite(startedAt) || !Number.isFinite(finishedAt)) return undefined;
|
||||||
|
const elapsed = finishedAt - startedAt;
|
||||||
|
return elapsed > 0 ? Math.max(1, Math.round(elapsed)) : undefined;
|
||||||
|
}
|
||||||
|
|
||||||
function formatDuration(value?: number) {
|
function formatDuration(value?: number) {
|
||||||
if (value === undefined || value === null) return '-';
|
if (value === undefined || value === null) return '-';
|
||||||
const milliseconds = Math.max(0, Math.round(value));
|
const milliseconds = Math.max(0, Math.round(value));
|
||||||
if (milliseconds === 0) return '0秒';
|
if (milliseconds === 0) return '-';
|
||||||
if (milliseconds < 1000) return `${milliseconds}毫秒`;
|
if (milliseconds < 1000) return `${milliseconds}毫秒`;
|
||||||
const totalSeconds = Math.round(milliseconds / 1000);
|
const totalSeconds = Math.round(milliseconds / 1000);
|
||||||
const hours = Math.floor(totalSeconds / 3600);
|
const hours = Math.floor(totalSeconds / 3600);
|
||||||
|
|||||||
@ -40,14 +40,14 @@ type ValueOption = { label: string; value: string };
|
|||||||
const textFields: FieldDefinition[] = [
|
const textFields: FieldDefinition[] = [
|
||||||
{ key: 'supportTool', label: '工具调用', hint: 'function calling / tools', type: 'boolean' },
|
{ key: 'supportTool', label: '工具调用', hint: 'function calling / tools', type: 'boolean' },
|
||||||
{ key: 'supportStructuredOutput', label: '结构化输出', hint: 'JSON Schema 等输出', type: 'boolean' },
|
{ key: 'supportStructuredOutput', label: '结构化输出', hint: 'JSON Schema 等输出', type: 'boolean' },
|
||||||
{ key: 'supportThinking', label: '思考能力', hint: '支持 thinking 参数', type: 'boolean' },
|
{ key: 'supportThinking', label: '推理能力', hint: '支持 reasoning / thinking 参数', type: 'boolean' },
|
||||||
{ key: 'supportThinkingModeSwitch', label: '思考开关', hint: '可按请求切换', type: 'boolean' },
|
{ key: 'supportThinkingModeSwitch', label: '思考开关', hint: '可按请求切换', type: 'boolean' },
|
||||||
{ key: 'supportWebSearch', label: '联网搜索', type: 'boolean' },
|
{ key: 'supportWebSearch', label: '联网搜索', type: 'boolean' },
|
||||||
{ key: 'max_context_tokens', label: '上下文 Token', placeholder: '128000', type: 'number' },
|
{ key: 'max_context_tokens', label: '上下文 Token', placeholder: '128000', type: 'number' },
|
||||||
{ key: 'max_input_tokens', label: '最大输入 Token', placeholder: '64000', type: 'number' },
|
{ key: 'max_input_tokens', label: '最大输入 Token', placeholder: '64000', type: 'number' },
|
||||||
{ key: 'max_output_tokens', label: '最大输出 Token', placeholder: '8192', type: 'number' },
|
{ key: 'max_output_tokens', label: '最大输出 Token', placeholder: '8192', type: 'number' },
|
||||||
{ key: 'max_thinking_tokens', label: '最大思考 Token', placeholder: '32768', type: 'number' },
|
{ key: 'max_thinking_tokens', label: '最大思考 Token', placeholder: '32768', type: 'number' },
|
||||||
{ key: 'thinkingEffortLevels', label: '思考强度', placeholder: 'minimal, low, medium, high', type: 'list' },
|
{ key: 'thinkingEffortLevels', label: '推理深度', hint: '声明模型支持的 reasoning_effort 取值,可填写 max 等供应商自定义值', placeholder: 'none, minimal, low, medium, high, xhigh, max', type: 'list' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const embeddingFields: FieldDefinition[] = [
|
const embeddingFields: FieldDefinition[] = [
|
||||||
@ -535,7 +535,7 @@ const imageAspectRatioOptions = [
|
|||||||
'7:4',
|
'7:4',
|
||||||
'4:7',
|
'4:7',
|
||||||
];
|
];
|
||||||
const thinkingEffortOptions = ['minimal', 'low', 'medium', 'high'];
|
const thinkingEffortOptions = ['none', 'minimal', 'low', 'medium', 'high', 'xhigh', 'max'];
|
||||||
const omniVideoModeOptions = ['text_to_video', 'image_reference', 'element_reference', 'first_last_frame', 'video_reference', 'video_edit', 'multi_shot'];
|
const omniVideoModeOptions = ['text_to_video', 'image_reference', 'element_reference', 'first_last_frame', 'video_reference', 'video_edit', 'multi_shot'];
|
||||||
const durationOptionValues = ['1', '2', '3', '4', '5', '6', '8', '10', '15', '20', '25', '30'];
|
const durationOptionValues = ['1', '2', '3', '4', '5', '6', '8', '10', '15', '20', '25', '30'];
|
||||||
const exclusiveCapabilityFields: Record<string, string> = {
|
const exclusiveCapabilityFields: Record<string, string> = {
|
||||||
|
|||||||
@ -32,8 +32,8 @@ export interface PlaygroundUpload {
|
|||||||
export type OpenAIChatContentPart =
|
export type OpenAIChatContentPart =
|
||||||
| { type: 'text'; text: string }
|
| { type: 'text'; text: string }
|
||||||
| { type: 'image_url'; image_url: { url: string } }
|
| { type: 'image_url'; image_url: { url: string } }
|
||||||
| { type: 'video_url'; video_url: { url: string } }
|
| { type: 'video_url'; video_url: { mime_type?: string; url: string } }
|
||||||
| { type: 'audio_url'; audio_url: { url: string } }
|
| { type: 'audio_url'; audio_url: { mime_type?: string; url: string } }
|
||||||
| { type: 'file_url'; file_url: { filename: string; url: string } };
|
| { type: 'file_url'; file_url: { filename: string; url: string } };
|
||||||
|
|
||||||
export const mediaUploadAccept = 'image/*,video/*,audio/*';
|
export const mediaUploadAccept = 'image/*,video/*,audio/*';
|
||||||
@ -518,11 +518,17 @@ export function openAIContentFromPromptAndUploads(prompt: string, uploads: Playg
|
|||||||
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
|
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
|
||||||
if (!item.url) return undefined;
|
if (!item.url) return undefined;
|
||||||
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
|
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
|
||||||
if (item.kind === 'video') return { type: 'video_url', video_url: { url: item.url } };
|
if (item.kind === 'video') return { type: 'video_url', video_url: mediaURLPayload(item) };
|
||||||
if (item.kind === 'audio') return { type: 'audio_url', audio_url: { url: item.url } };
|
if (item.kind === 'audio') return { type: 'audio_url', audio_url: mediaURLPayload(item) };
|
||||||
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
|
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function mediaURLPayload(item: PlaygroundUpload) {
|
||||||
|
const payload: { mime_type?: string; url: string } = { url: item.url };
|
||||||
|
if (item.contentType) payload.mime_type = item.contentType;
|
||||||
|
return payload;
|
||||||
|
}
|
||||||
|
|
||||||
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
|
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
|
||||||
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
|
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
|
||||||
const payload: Record<string, string | string[]> = {};
|
const payload: Record<string, string | string[]> = {};
|
||||||
@ -570,10 +576,10 @@ function videoGenerationContentFromUpload(item: PlaygroundUpload): VideoGenerati
|
|||||||
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
|
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
|
||||||
}
|
}
|
||||||
if (item.kind === 'video') {
|
if (item.kind === 'video') {
|
||||||
return { type: 'video_url', role: 'reference_video', video_url: { url: item.url, refer_type: 'feature' } };
|
return { type: 'video_url', role: 'reference_video', video_url: { ...mediaURLPayload(item), refer_type: 'feature' } };
|
||||||
}
|
}
|
||||||
if (item.kind === 'audio') {
|
if (item.kind === 'audio') {
|
||||||
return { type: 'audio_url', role: 'reference_audio', audio_url: { url: item.url } };
|
return { type: 'audio_url', role: 'reference_audio', audio_url: mediaURLPayload(item) };
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|||||||
65
devenv.lock
Normal file
65
devenv.lock
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"devenv": {
|
||||||
|
"locked": {
|
||||||
|
"dir": "src/modules",
|
||||||
|
"lastModified": 1778613747,
|
||||||
|
"narHash": "sha256-+FdF9iIvBQIC391Xkoso3IFIl/Iqp2NolSvCOgEIm78=",
|
||||||
|
"owner": "cachix",
|
||||||
|
"repo": "devenv",
|
||||||
|
"rev": "c9ee1d61986a6dde1cf45e738b01395cd5bce470",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"dir": "src/modules",
|
||||||
|
"owner": "cachix",
|
||||||
|
"repo": "devenv",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs-src": "nixpkgs-src"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1778507786,
|
||||||
|
"narHash": "sha256-HzSQCKMsMr8r55LwM1JuzIOB+8bzk0FEv6sItKvsfoY=",
|
||||||
|
"owner": "cachix",
|
||||||
|
"repo": "devenv-nixpkgs",
|
||||||
|
"rev": "8f24a228a782e24576b155d1e39f0d914b380691",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "cachix",
|
||||||
|
"ref": "rolling",
|
||||||
|
"repo": "devenv-nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nixpkgs-src": {
|
||||||
|
"flake": false,
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1778274207,
|
||||||
|
"narHash": "sha256-I4puXmX1iovcCHZlRmztO3vW0mAbbRvq4F8wgIMQ1MM=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "b3da656039dc7a6240f27b2ef8cc6a3ef3bccae7",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"ref": "nixpkgs-unstable",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"devenv": "devenv",
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
||||||
104
devenv.nix
Normal file
104
devenv.nix
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
{
|
||||||
|
pkgs,
|
||||||
|
lib,
|
||||||
|
config,
|
||||||
|
inputs,
|
||||||
|
...
|
||||||
|
}:
|
||||||
|
|
||||||
|
{
|
||||||
|
starship = {
|
||||||
|
enable = true;
|
||||||
|
config = {
|
||||||
|
enable = true;
|
||||||
|
path = ./starship.toml;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
env = {
|
||||||
|
AI_GATEWAY_DATABASE_NAME = "easyai_ai_gateway";
|
||||||
|
AI_GATEWAY_DATABASE_URL = "host=${config.env.DEVENV_RUNTIME}/postgres dbname=easyai_ai_gateway sslmode=disable";
|
||||||
|
AI_GATEWAY_SKIP_DB_CREATE = "1";
|
||||||
|
};
|
||||||
|
|
||||||
|
packages = with pkgs; [
|
||||||
|
curl
|
||||||
|
docker-client
|
||||||
|
git
|
||||||
|
jq
|
||||||
|
lsof
|
||||||
|
postgresql_18
|
||||||
|
ripgrep
|
||||||
|
watchexec
|
||||||
|
];
|
||||||
|
|
||||||
|
scripts = {
|
||||||
|
dev.exec = "pnpm dev";
|
||||||
|
build.exec = "pnpm build";
|
||||||
|
test-all.exec = "pnpm test";
|
||||||
|
lint.exec = "pnpm lint";
|
||||||
|
migrate.exec = "pnpm migrate";
|
||||||
|
db-create.exec = "pnpm db:create";
|
||||||
|
api-test.exec = "pnpm nx run api:test";
|
||||||
|
web-build.exec = "pnpm nx run web:build";
|
||||||
|
};
|
||||||
|
|
||||||
|
services.postgres = {
|
||||||
|
enable = true;
|
||||||
|
package = pkgs.postgresql_18.withPackages (postgresPackages: [
|
||||||
|
postgresPackages.pgvector
|
||||||
|
]);
|
||||||
|
listen_addresses = "";
|
||||||
|
initialDatabases = [
|
||||||
|
{
|
||||||
|
name = "easyai_ai_gateway";
|
||||||
|
initialSQL = ''
|
||||||
|
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||||
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
|
'';
|
||||||
|
}
|
||||||
|
];
|
||||||
|
};
|
||||||
|
|
||||||
|
# https://devenv.sh/languages/
|
||||||
|
languages.go = {
|
||||||
|
enable = true;
|
||||||
|
package = pkgs.go;
|
||||||
|
};
|
||||||
|
|
||||||
|
languages.javascript = {
|
||||||
|
enable = true;
|
||||||
|
package = pkgs.nodejs_22;
|
||||||
|
nodejs.enable = true;
|
||||||
|
lsp.enable = true;
|
||||||
|
pnpm = {
|
||||||
|
enable = true;
|
||||||
|
install.enable = true;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
enterShell = ''
|
||||||
|
echo ""
|
||||||
|
echo "EasyAI AI Gateway 开发环境已就绪"
|
||||||
|
echo "当前目录:$PWD"
|
||||||
|
echo ""
|
||||||
|
echo "运行时版本:"
|
||||||
|
echo " go: $(go version | awk '{print $3}')"
|
||||||
|
echo " node: $(node --version)"
|
||||||
|
echo " pnpm: $(pnpm --version)"
|
||||||
|
echo " psql: $(psql --version | awk '{print $3}')"
|
||||||
|
echo ""
|
||||||
|
echo "常用命令:"
|
||||||
|
echo " dev 创建/迁移数据库,并启动 API 和 Web"
|
||||||
|
echo " test-all 运行 API 和 Web 测试目标"
|
||||||
|
echo " build 构建 API 和 Web"
|
||||||
|
echo " lint 运行 Web 与 contracts 类型检查"
|
||||||
|
echo " migrate 执行 API 数据库迁移"
|
||||||
|
echo " db-create 创建 AI Gateway 数据库"
|
||||||
|
echo " api-test 只运行 Go API 测试"
|
||||||
|
echo " web-build 只构建 Web 前端"
|
||||||
|
echo ""
|
||||||
|
echo "提示:根 package.json 是唯一脚本入口;上述短命令由 devenv scripts 转发。"
|
||||||
|
echo ""
|
||||||
|
'';
|
||||||
|
}
|
||||||
18
devenv.yaml
Normal file
18
devenv.yaml
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# yaml-language-server: $schema=https://devenv.sh/devenv.schema.json
|
||||||
|
inputs:
|
||||||
|
nixpkgs:
|
||||||
|
url: github:cachix/devenv-nixpkgs/rolling
|
||||||
|
|
||||||
|
# If you're using non-OSS software, you can set allowUnfree to true.
|
||||||
|
# allowUnfree: true
|
||||||
|
|
||||||
|
# If you're not willing to allow unsupported packages:
|
||||||
|
# allowUnsupportedSystem: false
|
||||||
|
|
||||||
|
# If you're willing to use a package that's vulnerable
|
||||||
|
# permittedInsecurePackages:
|
||||||
|
# - "openssl-1.1.1w"
|
||||||
|
|
||||||
|
# If you have more than one devenv you can merge them
|
||||||
|
#imports:
|
||||||
|
# - ./backend
|
||||||
@ -1505,6 +1505,22 @@ type ModelClient interface {
|
|||||||
- progress event snapshot:确保前端进度面板兼容。
|
- progress event snapshot:确保前端进度面板兼容。
|
||||||
- billing snapshot:确保预估扣费和最终 billings 语义一致。
|
- billing snapshot:确保预估扣费和最终 billings 语义一致。
|
||||||
|
|
||||||
|
OpenAI-compatible 文本请求中的推理深度统一使用 `reasoning_effort` 表达。该字段是请求参数,不是响应中的推理内容;模型能力中用 `thinkingEffortLevels` 声明该模型支持的可选取值。`reasoning_effort` 必须按开放字符串处理,不在网关层写死枚举;实际可用集合必须以 provider 和模型能力为准。常见取值定义如下:
|
||||||
|
|
||||||
|
| 值 | 含义 |
|
||||||
|
| --- | --- |
|
||||||
|
| `none` | 不启用额外推理,适用于不需要思考链路的低延迟请求。 |
|
||||||
|
| `minimal` | 最小推理预算,优先降低延迟和成本。 |
|
||||||
|
| `low` | 较低推理预算,用于简单推理任务。 |
|
||||||
|
| `medium` | 默认/均衡推理深度,在质量、延迟和成本之间折中。 |
|
||||||
|
| `high` | 较高推理预算,用于复杂规划、代码和多步推理。 |
|
||||||
|
| `xhigh` | 最高推理预算,仅在模型和 provider 明确支持时使用,通常成本和延迟最高。 |
|
||||||
|
| `max` | 供应商自定义最高档示例,例如 DeepSeek V4 类模型可能使用该值;语义以 provider 文档为准。 |
|
||||||
|
|
||||||
|
除上表外,`thinkingEffortLevels` 可以保存任意供应商自定义值,例如 `max`、`ultra` 或后续模型新增档位。管理端只提供常见值作为快捷选项,不应阻止自定义输入;请求透传时按模型能力校验或直接交由上游 provider 返回错误。
|
||||||
|
|
||||||
|
`reasoning_content`、推理过程 delta 或思考摘要在 Chat Completions 中不是 OpenAI 标准必需字段;如需兼容 DeepSeek、Qwen 等供应商扩展,应在 adapter 层作为可选扩展透传,并避免把 hidden reasoning 默认暴露给普通兼容客户端。
|
||||||
|
|
||||||
## 11. 队列持久化、恢复与限流执行
|
## 11. 队列持久化、恢复与限流执行
|
||||||
|
|
||||||
### 11.1 持久化队列原则
|
### 11.1 持久化队列原则
|
||||||
|
|||||||
34
docs/media-client-migration.md
Normal file
34
docs/media-client-migration.md
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Media Client Migration Status
|
||||||
|
|
||||||
|
This document tracks the server-main media runtime migration into the AI Gateway.
|
||||||
|
|
||||||
|
## Runtime Scope
|
||||||
|
|
||||||
|
- Included model types: `image_generate`, `image_edit`, `video_generate`, `image_to_video`, `omni_video`, `video_edit`.
|
||||||
|
- Excluded provider: `easyai`, because routing AI Gateway media tasks back into server-main would create a loopback dependency.
|
||||||
|
- Universal custom scripts are supported through `integration_platforms.config`:
|
||||||
|
- `customPreprocessScript`
|
||||||
|
- `customGetParamsScript`
|
||||||
|
- `customSubmitScript`
|
||||||
|
- `customPollScript`
|
||||||
|
- `getTaskURL`
|
||||||
|
- `skipParamNormalization`
|
||||||
|
|
||||||
|
## Migrated Clients
|
||||||
|
|
||||||
|
- `universal`: custom preprocess/get params/submit/poll scripts, default submit/poll, remote task resume.
|
||||||
|
- `jimeng`: async submit/poll skeleton with Jimeng task id and status mapping.
|
||||||
|
- `blackforest`: submit with `x-key`, `polling_url` polling, image result normalization.
|
||||||
|
- `tencent-hunyuan-image`: Tencent-style `Response.JobId`/`Response.Status` image task mapping.
|
||||||
|
- `tencent-hunyuan-video`: Tencent-style `Response.JobId`/`Response.Status` video task mapping.
|
||||||
|
- `minimax`: video submit/query task mapping.
|
||||||
|
- `midjourney`: diffusion submit, job polling, original and Aliyun-style status/result mapping.
|
||||||
|
- `vidu`: Token auth, typed submit path, creations polling.
|
||||||
|
- `aliyun-bailian`: video synthesis submit and task polling.
|
||||||
|
- `newapi`: `/videos/generations` submit and task polling.
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Provider-specific advanced parameter shaping remains isolated inside each client/spec.
|
||||||
|
- Tencent and Jimeng production deployments should configure exact submit/poll paths and credentials in platform config when they differ from the default server-main-compatible paths.
|
||||||
|
- Each migrated client has an `httptest` submit/poll coverage case in `internal/clients`.
|
||||||
@ -68,7 +68,11 @@ export AI_GATEWAY_DATABASE_URL="${AI_GATEWAY_DATABASE_URL:-postgresql://${AI_GAT
|
|||||||
|
|
||||||
echo "[ai-gateway] using database: ${AI_GATEWAY_DATABASE_URL}"
|
echo "[ai-gateway] using database: ${AI_GATEWAY_DATABASE_URL}"
|
||||||
|
|
||||||
scripts/create-database.sh
|
if [[ "${AI_GATEWAY_SKIP_DB_CREATE:-}" == "1" ]]; then
|
||||||
|
echo "[ai-gateway] skipping Docker database creation"
|
||||||
|
else
|
||||||
|
scripts/create-database.sh
|
||||||
|
fi
|
||||||
pnpm nx run api:migrate
|
pnpm nx run api:migrate
|
||||||
stop_stale_api_processes
|
stop_stale_api_processes
|
||||||
exec pnpm nx run-many -t dev -p api web --parallel=2
|
exec pnpm nx run-many -t dev -p api web --parallel=2
|
||||||
|
|||||||
101
starship.toml
Normal file
101
starship.toml
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# --- 全局结构 (极致紧凑,高信息密度,AI友好) ---
|
||||||
|
format = """
|
||||||
|
$directory\
|
||||||
|
$git_branch\
|
||||||
|
$git_status\
|
||||||
|
$nix_shell\
|
||||||
|
$nodejs\
|
||||||
|
$bun\
|
||||||
|
$rust\
|
||||||
|
$golang\
|
||||||
|
$cmd_duration\
|
||||||
|
$memory_usage\
|
||||||
|
$status\
|
||||||
|
$line_break\
|
||||||
|
$character"""
|
||||||
|
|
||||||
|
# --- 目录 ---
|
||||||
|
[directory]
|
||||||
|
style = "bold cyan"
|
||||||
|
truncation_length = 3
|
||||||
|
truncate_to_repo = false
|
||||||
|
truncation_symbol = ".../"
|
||||||
|
home_symbol = "~"
|
||||||
|
read_only = " [RO]"
|
||||||
|
|
||||||
|
# --- Git 状态 (纯文本,避免解析错误与乱码) ---
|
||||||
|
[git_branch]
|
||||||
|
symbol = "git:"
|
||||||
|
style = "bold purple"
|
||||||
|
format = "[$symbol$branch]($style) "
|
||||||
|
|
||||||
|
[git_status]
|
||||||
|
format = "[$all_status$ahead_behind]($style) "
|
||||||
|
style = "bold red"
|
||||||
|
conflicted = "="
|
||||||
|
ahead = ">"
|
||||||
|
behind = "<"
|
||||||
|
diverged = "<>"
|
||||||
|
untracked = "?"
|
||||||
|
stashed = "*"
|
||||||
|
modified = "!"
|
||||||
|
staged = "+"
|
||||||
|
renamed = "»"
|
||||||
|
deleted = "x"
|
||||||
|
|
||||||
|
# --- 编程语言与环境 (紧凑标签格式) ---
|
||||||
|
[nodejs]
|
||||||
|
symbol = "node:"
|
||||||
|
style = "bold green"
|
||||||
|
format = "[$symbol$version]($style) "
|
||||||
|
detect_files = ["package.json", ".node-version"]
|
||||||
|
|
||||||
|
[bun]
|
||||||
|
symbol = "bun:"
|
||||||
|
style = "bold blue"
|
||||||
|
format = "[$symbol$version]($style) "
|
||||||
|
|
||||||
|
[rust]
|
||||||
|
symbol = "rust:"
|
||||||
|
style = "bold 208"
|
||||||
|
format = "[$symbol$version]($style) "
|
||||||
|
|
||||||
|
[golang]
|
||||||
|
symbol = "go:"
|
||||||
|
style = "bold cyan"
|
||||||
|
format = "[$symbol$version]($style) "
|
||||||
|
|
||||||
|
[nix_shell]
|
||||||
|
symbol = "nix:"
|
||||||
|
style = "bold blue"
|
||||||
|
format = "[$symbol$state]($style) "
|
||||||
|
impure_msg = "impure"
|
||||||
|
pure_msg = "pure"
|
||||||
|
|
||||||
|
# --- AI 辅助决策信息 (性能与状态反馈) ---
|
||||||
|
[cmd_duration]
|
||||||
|
min_time = 2_000
|
||||||
|
format = "took [$duration]($style) "
|
||||||
|
style = "bold yellow"
|
||||||
|
|
||||||
|
[memory_usage]
|
||||||
|
symbol = "mem:"
|
||||||
|
disabled = false
|
||||||
|
threshold = 75
|
||||||
|
format = "[$symbol$ram_pct]($style) "
|
||||||
|
style = "bold dimmed white"
|
||||||
|
|
||||||
|
[status]
|
||||||
|
disabled = false
|
||||||
|
format = "[ERR:$status]($style) "
|
||||||
|
style = "bold red"
|
||||||
|
|
||||||
|
# --- 交互符号 (带明确状态码) ---
|
||||||
|
[character]
|
||||||
|
success_symbol = "[>](bold green)"
|
||||||
|
error_symbol = "[>](bold red)"
|
||||||
|
vicmd_symbol = "[<](bold green)"
|
||||||
|
|
||||||
|
# --- 兼容性补充 ---
|
||||||
|
[package]
|
||||||
|
disabled = true
|
||||||
Loading…
Reference in New Issue
Block a user