diff --git a/apps/api/internal/httpapi/core_flow_integration_test.go b/apps/api/internal/httpapi/core_flow_integration_test.go index 3b4047b..41bf15f 100644 --- a/apps/api/internal/httpapi/core_flow_integration_test.go +++ b/apps/api/internal/httpapi/core_flow_integration_test.go @@ -783,6 +783,24 @@ WHERE reference_type = 'gateway_task' if len(asyncRateLimitDetail.Attempts) != 0 { t.Fatalf("async rate-limited task should wait in queue without recording a failed attempt: %+v", asyncRateLimitDetail) } + var modelRateLimits struct { + Items []struct { + ModelName string `json:"modelName"` + ModelAlias string `json:"modelAlias"` + QueuedTasks float64 `json:"queuedTasks"` + } `json:"items"` + } + doJSON(t, server.URL, http.MethodGet, "/api/admin/runtime/model-rate-limits", loginResponse.AccessToken, nil, http.StatusOK, &modelRateLimits) + var queuedTasks float64 + for _, item := range modelRateLimits.Items { + if item.ModelName == rateLimitedModel || item.ModelAlias == rateLimitedModel { + queuedTasks = item.QueuedTasks + break + } + } + if queuedTasks < 1 { + t.Fatalf("realtime load should count async rate-limited task as queued, got %v in %+v", queuedTasks, modelRateLimits.Items) + } asyncRateLimitCompleted := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"succeeded"}, time.Duration(rateLimitWindowSeconds+3)*time.Second) if asyncRateLimitCompleted.Status != "succeeded" { t.Fatalf("async rate-limited task should be pulled from queue after the limit window resets, got %+v", asyncRateLimitCompleted) diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index 55dde51..8fc8780 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -121,11 +121,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy) attemptNo := task.AttemptCount var lastErr error + var lastCandidate store.RuntimeModelCandidate candidatesLoop: for index, candidate := range candidates { if index >= maxPlatforms { break } + lastCandidate = candidate clientAttempts := clientAttemptsForCandidate(candidate) var candidateErr error for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ { @@ -180,7 +182,7 @@ candidatesLoop: lastErr = err candidateErr = err if task.AsyncMode && store.RateLimitRetryable(err) { - queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, err) + queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, err, candidate) if queueErr != nil { return Result{}, queueErr } @@ -297,7 +299,7 @@ candidatesLoop: return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0} } if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) { - queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr) + queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr, lastCandidate) if queueErr != nil { return Result{}, queueErr } @@ -525,12 +527,12 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess return failed, nil } -func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) { +func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error, candidate store.RuntimeModelCandidate) (store.GatewayTask, time.Duration, error) { delay := localRateLimitRetryAfter(cause) if delay <= 0 { delay = 5 * time.Second } - queued, err := s.store.RequeueTask(ctx, task.ID, delay) + queued, err := s.store.RequeueTask(ctx, task.ID, delay, candidate.QueueKey) if err != nil { return store.GatewayTask{}, 0, err } @@ -546,7 +548,7 @@ func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.Gateway } func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) { - queued, err := s.store.RequeueTask(ctx, task.ID, 0) + queued, err := s.store.RequeueTask(ctx, task.ID, 0, "") if err != nil { return store.GatewayTask{}, err } diff --git a/apps/api/internal/store/rate_limit_status.go b/apps/api/internal/store/rate_limit_status.go index 2bb0152..15c8439 100644 --- a/apps/api/internal/store/rate_limit_status.go +++ b/apps/api/internal/store/rate_limit_status.go @@ -62,17 +62,31 @@ LEFT JOIN ( GROUP BY scope_key ) con ON con.scope_key = m.id::text LEFT JOIN ( - SELECT latest.platform_model_id, COUNT(*) AS waiting + SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting FROM ( - SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id + SELECT t.id::text AS task_id, qm.id::text AS platform_model_id FROM gateway_tasks t - JOIN gateway_task_attempts a ON a.task_id = t.id + JOIN integration_platforms qp ON TRUE + JOIN platform_models qm ON qm.platform_id = qp.id WHERE t.async_mode = true AND t.status = 'queued' - AND a.platform_model_id IS NOT NULL - ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC - ) latest - GROUP BY latest.platform_model_id + AND NULLIF(t.model_type, '') IS NOT NULL + AND qm.model_type @> jsonb_build_array(t.model_type) + AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name) + AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id) + UNION ALL + SELECT latest.task_id, latest.platform_model_id + FROM ( + SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id + FROM gateway_tasks t + JOIN gateway_task_attempts a ON a.task_id = t.id + WHERE t.async_mode = true + AND t.status = 'queued' + AND a.platform_model_id IS NOT NULL + ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC + ) latest + ) queued_sources + GROUP BY queued_sources.platform_model_id ) queued ON queued.platform_model_id = m.id::text LEFT JOIN ( SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index a3a1f42..2776e4c 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -189,7 +189,7 @@ WHERE t.id = picked.task_id RETURNING `+gatewayTaskColumns, workerID)) } -func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) { +func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration, queueKey string) (GatewayTask, error) { if delay < time.Second { delay = time.Second } @@ -204,12 +204,13 @@ SET status = 'queued', locked_at = NULL, heartbeat_at = NULL, next_run_at = $2::timestamptz, + queue_key = COALESCE(NULLIF($3::text, ''), queue_key), error = NULL, error_code = NULL, error_message = NULL, updated_at = now() WHERE id = $1::uuid -RETURNING `+gatewayTaskColumns, taskID, nextRunAt)) +RETURNING `+gatewayTaskColumns, taskID, nextRunAt, strings.TrimSpace(queueKey))) } func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error { diff --git a/apps/web/src/pages/PlaygroundPage.tsx b/apps/web/src/pages/PlaygroundPage.tsx index 8ef1a34..7fd456a 100644 --- a/apps/web/src/pages/PlaygroundPage.tsx +++ b/apps/web/src/pages/PlaygroundPage.tsx @@ -120,6 +120,7 @@ export function PlaygroundPage(props: { const [mediaRuns, setMediaRuns] = useState(readStoredMediaRuns); const [mediaMessage, setMediaMessage] = useState(''); const isMountedRef = useRef(false); + const pendingMediaModelRef = useRef(''); const resumedTaskIdsRef = useRef(new Set()); const activeMode = useMemo(() => modeOptions.find((item) => item.value === props.mode) ?? modeOptions[0], [props.mode]); const modelOptions = useMemo( @@ -139,7 +140,17 @@ export function PlaygroundPage(props: { ); useEffect(() => { - setSelectedModel((current) => modelOptions.some((item) => item.value === current) ? current : modelOptions[0]?.value ?? ''); + setSelectedModel((current) => { + const pendingModel = pendingMediaModelRef.current; + if (pendingModel) { + const resolvedPending = resolveModelOptionValue(pendingModel, modelOptions); + if (resolvedPending) { + pendingMediaModelRef.current = ''; + return resolvedPending; + } + } + return modelOptions.some((item) => item.value === current) ? current : modelOptions[0]?.value ?? ''; + }); }, [modelOptions]); useEffect(() => { @@ -186,6 +197,7 @@ export function PlaygroundPage(props: { }, [activeApiKeySecret, mediaRuns, props.token]); async function submitMediaTask(overrides?: { + model?: string; mode?: Exclude; prompt?: string; settings?: MediaGenerationSettings; @@ -204,7 +216,8 @@ export function PlaygroundPage(props: { setMediaMessage('请选择可用于测试的 API Key;如果列表为空,请先创建一个 Key。'); return; } - if (!selectedModel) { + const runModel = overrides?.model ?? selectedModel; + if (!runModel) { setMediaMessage('当前没有可用模型,请确认用户组权限或平台模型配置。'); return; } @@ -214,12 +227,13 @@ export function PlaygroundPage(props: { } const localId = newLocalId(); - const modelLabel = modelOptions.find((item) => item.value === selectedModel)?.label ?? selectedModel; + const modelLabel = modelOptions.find((item) => item.value === runModel)?.label ?? runModel; const run: MediaGenerationRun = { createdAt: new Date().toISOString(), localId, mode: runMode, modelLabel, + modelValue: runModel, prompt: trimmedPrompt, settings: runSettings, status: 'submitting', @@ -229,7 +243,7 @@ export function PlaygroundPage(props: { setMediaMessage(''); try { const requestPayload = { - model: selectedModel, + model: runModel, prompt: trimmedPrompt, ...mediaRequestPayload(runSettings, runMode), }; @@ -268,24 +282,38 @@ export function PlaygroundPage(props: { })); } + function selectMediaRunModel(run: MediaGenerationRun) { + const runModel = resolveMediaRunModelValue(run, modelOptions); + if (runModel) { + pendingMediaModelRef.current = ''; + setSelectedModel(runModel); + return runModel; + } + const fallbackModel = firstString(run.modelValue, run.task?.requestedModel, taskRequestModel(run.task), run.task?.model, run.task?.resolvedModel); + pendingMediaModelRef.current = fallbackModel; + return fallbackModel; + } + function editMediaRun(run: MediaGenerationRun) { setPrompt(run.prompt); setMediaSettings(run.settings); + selectMediaRunModel(run); if (props.mode !== run.mode) { props.onModeChange(run.mode); } - setMediaMessage('已带入这条任务的提示词和参数,可调整后再次生成。'); + setMediaMessage('已带入这条任务的模型、提示词和参数,可调整后再次生成。'); } function rerunMediaRun(run: MediaGenerationRun) { setPrompt(run.prompt); setMediaSettings(run.settings); + const runModel = selectMediaRunModel(run); if (props.mode !== run.mode) { props.onModeChange(run.mode); - setMediaMessage('已切换到对应模式并带入参数,请确认模型后再次生成。'); + setMediaMessage('已切换到对应模式并带入模型和参数,请确认后再次生成。'); return; } - void submitMediaTask({ mode: run.mode, prompt: run.prompt, settings: run.settings }); + void submitMediaTask({ mode: run.mode, model: runModel, prompt: run.prompt, settings: run.settings }); } const mediaComposer = props.mode === 'chat' ? null : ( @@ -1043,6 +1071,46 @@ function modelOptionLabel(option: ModelOption) { return `${option.label}${provider}${count}`; } +function resolveMediaRunModelValue(run: MediaGenerationRun, modelOptions: ModelOption[]) { + const candidates = [ + run.modelValue, + run.task?.requestedModel, + taskRequestModel(run.task), + run.task?.model, + run.task?.resolvedModel, + ]; + for (const candidate of candidates) { + const value = resolveModelOptionValue(candidate, modelOptions); + if (value) return value; + } + return ''; +} + +function resolveModelOptionValue(value: unknown, modelOptions: ModelOption[]) { + const raw = stringFromUnknown(value); + if (!raw) return ''; + const direct = modelOptions.find((item) => item.value === raw); + if (direct) return direct.value; + const matched = modelOptions.find((item) => item.models.some((model) => ( + model.modelAlias === raw + || model.modelName === raw + || model.displayName === raw + ))); + return matched?.value ?? ''; +} + +function taskRequestModel(task: GatewayTask | undefined) { + return stringFromUnknown(task?.request?.model); +} + +function firstString(...values: unknown[]) { + for (const value of values) { + const text = stringFromUnknown(value); + if (text) return text; + } + return ''; +} + function updateMediaRun(runs: MediaGenerationRun[], localId: string, patch: Partial) { return runs.map((run) => run.localId === localId ? { ...run, ...patch } : run); } @@ -1091,7 +1159,8 @@ function mediaRunFromStorage(value: unknown, index: number): MediaGenerationRun const task = taskFromStorage(record.task); const createdAt = dateStringFromUnknown(record.createdAt) ?? new Date().toISOString(); const localId = stringFromUnknown(record.localId) || task?.id || `stored-${index}-${createdAt}`; - const modelLabel = stringFromUnknown(record.modelLabel) || task?.model || '未知模型'; + const modelValue = stringFromUnknown(record.modelValue) || task?.requestedModel || taskRequestModel(task) || task?.model || ''; + const modelLabel = stringFromUnknown(record.modelLabel) || modelValue || '未知模型'; let status: MediaGenerationRun['status'] = stringFromUnknown(record.status) || task?.status || 'failed'; let error = stringFromUnknown(record.error); if (status === 'submitting' && !task?.id) { @@ -1104,6 +1173,7 @@ function mediaRunFromStorage(value: unknown, index: number): MediaGenerationRun localId, mode, modelLabel, + modelValue, prompt, settings: mediaSettingsFromStorage(record.settings), status, diff --git a/apps/web/src/pages/playground-media.tsx b/apps/web/src/pages/playground-media.tsx index 4496b3d..f692f57 100644 --- a/apps/web/src/pages/playground-media.tsx +++ b/apps/web/src/pages/playground-media.tsx @@ -41,6 +41,7 @@ export interface MediaGenerationRun { localId: string; mode: Exclude; modelLabel: string; + modelValue?: string; prompt: string; settings: MediaGenerationSettings; status: GatewayTask['status'] | 'submitting';