Fix realtime queued task counts

This commit is contained in:
wangbo 2026-05-12 10:41:07 +08:00
parent 2a91b31d12
commit 9ea83be718
6 changed files with 128 additions and 22 deletions

View File

@ -783,6 +783,24 @@ WHERE reference_type = 'gateway_task'
if len(asyncRateLimitDetail.Attempts) != 0 {
t.Fatalf("async rate-limited task should wait in queue without recording a failed attempt: %+v", asyncRateLimitDetail)
}
var modelRateLimits struct {
Items []struct {
ModelName string `json:"modelName"`
ModelAlias string `json:"modelAlias"`
QueuedTasks float64 `json:"queuedTasks"`
} `json:"items"`
}
doJSON(t, server.URL, http.MethodGet, "/api/admin/runtime/model-rate-limits", loginResponse.AccessToken, nil, http.StatusOK, &modelRateLimits)
var queuedTasks float64
for _, item := range modelRateLimits.Items {
if item.ModelName == rateLimitedModel || item.ModelAlias == rateLimitedModel {
queuedTasks = item.QueuedTasks
break
}
}
if queuedTasks < 1 {
t.Fatalf("realtime load should count async rate-limited task as queued, got %v in %+v", queuedTasks, modelRateLimits.Items)
}
asyncRateLimitCompleted := waitForTaskStatus(t, server.URL, apiKeyResponse.Secret, asyncRateLimitTask.TaskID, []string{"succeeded"}, time.Duration(rateLimitWindowSeconds+3)*time.Second)
if asyncRateLimitCompleted.Status != "succeeded" {
t.Fatalf("async rate-limited task should be pulled from queue after the limit window resets, got %+v", asyncRateLimitCompleted)

View File

@ -121,11 +121,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
attemptNo := task.AttemptCount
var lastErr error
var lastCandidate store.RuntimeModelCandidate
candidatesLoop:
for index, candidate := range candidates {
if index >= maxPlatforms {
break
}
lastCandidate = candidate
clientAttempts := clientAttemptsForCandidate(candidate)
var candidateErr error
for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
@ -180,7 +182,7 @@ candidatesLoop:
lastErr = err
candidateErr = err
if task.AsyncMode && store.RateLimitRetryable(err) {
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, err)
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, err, candidate)
if queueErr != nil {
return Result{}, queueErr
}
@ -297,7 +299,7 @@ candidatesLoop:
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0}
}
if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) {
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr)
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr, lastCandidate)
if queueErr != nil {
return Result{}, queueErr
}
@ -525,12 +527,12 @@ func (s *Service) failTask(ctx context.Context, taskID string, code string, mess
return failed, nil
}
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) {
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error, candidate store.RuntimeModelCandidate) (store.GatewayTask, time.Duration, error) {
delay := localRateLimitRetryAfter(cause)
if delay <= 0 {
delay = 5 * time.Second
}
queued, err := s.store.RequeueTask(ctx, task.ID, delay)
queued, err := s.store.RequeueTask(ctx, task.ID, delay, candidate.QueueKey)
if err != nil {
return store.GatewayTask{}, 0, err
}
@ -546,7 +548,7 @@ func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.Gateway
}
func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) {
queued, err := s.store.RequeueTask(ctx, task.ID, 0)
queued, err := s.store.RequeueTask(ctx, task.ID, 0, "")
if err != nil {
return store.GatewayTask{}, err
}

View File

@ -62,17 +62,31 @@ LEFT JOIN (
GROUP BY scope_key
) con ON con.scope_key = m.id::text
LEFT JOIN (
SELECT latest.platform_model_id, COUNT(*) AS waiting
SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
JOIN integration_platforms qp ON TRUE
JOIN platform_models qm ON qm.platform_id = qp.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
GROUP BY latest.platform_model_id
AND NULLIF(t.model_type, '') IS NOT NULL
AND qm.model_type @> jsonb_build_array(t.model_type)
AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
UNION ALL
SELECT latest.task_id, latest.platform_model_id
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
) queued_sources
GROUP BY queued_sources.platform_model_id
) queued ON queued.platform_model_id = m.id::text
LEFT JOIN (
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at

View File

@ -189,7 +189,7 @@ WHERE t.id = picked.task_id
RETURNING `+gatewayTaskColumns, workerID))
}
func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration) (GatewayTask, error) {
func (s *Store) RequeueTask(ctx context.Context, taskID string, delay time.Duration, queueKey string) (GatewayTask, error) {
if delay < time.Second {
delay = time.Second
}
@ -204,12 +204,13 @@ SET status = 'queued',
locked_at = NULL,
heartbeat_at = NULL,
next_run_at = $2::timestamptz,
queue_key = COALESCE(NULLIF($3::text, ''), queue_key),
error = NULL,
error_code = NULL,
error_message = NULL,
updated_at = now()
WHERE id = $1::uuid
RETURNING `+gatewayTaskColumns, taskID, nextRunAt))
RETURNING `+gatewayTaskColumns, taskID, nextRunAt, strings.TrimSpace(queueKey)))
}
func (s *Store) SetTaskRiverJobID(ctx context.Context, taskID string, riverJobID int64) error {

View File

@ -120,6 +120,7 @@ export function PlaygroundPage(props: {
const [mediaRuns, setMediaRuns] = useState<MediaGenerationRun[]>(readStoredMediaRuns);
const [mediaMessage, setMediaMessage] = useState('');
const isMountedRef = useRef(false);
const pendingMediaModelRef = useRef('');
const resumedTaskIdsRef = useRef(new Set<string>());
const activeMode = useMemo(() => modeOptions.find((item) => item.value === props.mode) ?? modeOptions[0], [props.mode]);
const modelOptions = useMemo(
@ -139,7 +140,17 @@ export function PlaygroundPage(props: {
);
useEffect(() => {
setSelectedModel((current) => modelOptions.some((item) => item.value === current) ? current : modelOptions[0]?.value ?? '');
setSelectedModel((current) => {
const pendingModel = pendingMediaModelRef.current;
if (pendingModel) {
const resolvedPending = resolveModelOptionValue(pendingModel, modelOptions);
if (resolvedPending) {
pendingMediaModelRef.current = '';
return resolvedPending;
}
}
return modelOptions.some((item) => item.value === current) ? current : modelOptions[0]?.value ?? '';
});
}, [modelOptions]);
useEffect(() => {
@ -186,6 +197,7 @@ export function PlaygroundPage(props: {
}, [activeApiKeySecret, mediaRuns, props.token]);
async function submitMediaTask(overrides?: {
model?: string;
mode?: Exclude<PlaygroundMode, 'chat'>;
prompt?: string;
settings?: MediaGenerationSettings;
@ -204,7 +216,8 @@ export function PlaygroundPage(props: {
setMediaMessage('请选择可用于测试的 API Key如果列表为空请先创建一个 Key。');
return;
}
if (!selectedModel) {
const runModel = overrides?.model ?? selectedModel;
if (!runModel) {
setMediaMessage('当前没有可用模型,请确认用户组权限或平台模型配置。');
return;
}
@ -214,12 +227,13 @@ export function PlaygroundPage(props: {
}
const localId = newLocalId();
const modelLabel = modelOptions.find((item) => item.value === selectedModel)?.label ?? selectedModel;
const modelLabel = modelOptions.find((item) => item.value === runModel)?.label ?? runModel;
const run: MediaGenerationRun = {
createdAt: new Date().toISOString(),
localId,
mode: runMode,
modelLabel,
modelValue: runModel,
prompt: trimmedPrompt,
settings: runSettings,
status: 'submitting',
@ -229,7 +243,7 @@ export function PlaygroundPage(props: {
setMediaMessage('');
try {
const requestPayload = {
model: selectedModel,
model: runModel,
prompt: trimmedPrompt,
...mediaRequestPayload(runSettings, runMode),
};
@ -268,24 +282,38 @@ export function PlaygroundPage(props: {
}));
}
function selectMediaRunModel(run: MediaGenerationRun) {
  // Prefer a model that is selectable in the current option list.
  const resolved = resolveMediaRunModelValue(run, modelOptions);
  if (resolved) {
    pendingMediaModelRef.current = '';
    setSelectedModel(resolved);
    return resolved;
  }
  // No selectable match right now: stash the raw model name so the
  // modelOptions effect can apply it once the option list refreshes.
  const fallback = firstString(
    run.modelValue,
    run.task?.requestedModel,
    taskRequestModel(run.task),
    run.task?.model,
    run.task?.resolvedModel,
  );
  pendingMediaModelRef.current = fallback;
  return fallback;
}
function editMediaRun(run: MediaGenerationRun) {
setPrompt(run.prompt);
setMediaSettings(run.settings);
selectMediaRunModel(run);
if (props.mode !== run.mode) {
props.onModeChange(run.mode);
}
setMediaMessage('已带入这条任务的提示词和参数,可调整后再次生成。');
setMediaMessage('已带入这条任务的模型、提示词和参数,可调整后再次生成。');
}
function rerunMediaRun(run: MediaGenerationRun) {
setPrompt(run.prompt);
setMediaSettings(run.settings);
const runModel = selectMediaRunModel(run);
if (props.mode !== run.mode) {
props.onModeChange(run.mode);
setMediaMessage('已切换到对应模式并带入参数,请确认模型后再次生成。');
setMediaMessage('已切换到对应模式并带入模型和参数,请确认后再次生成。');
return;
}
void submitMediaTask({ mode: run.mode, prompt: run.prompt, settings: run.settings });
void submitMediaTask({ mode: run.mode, model: runModel, prompt: run.prompt, settings: run.settings });
}
const mediaComposer = props.mode === 'chat' ? null : (
@ -1043,6 +1071,46 @@ function modelOptionLabel(option: ModelOption) {
return `${option.label}${provider}${count}`;
}
// Resolve a stored media run back to a selectable model option value.
// Sources are checked most-specific first; returns '' when none match.
function resolveMediaRunModelValue(run: MediaGenerationRun, modelOptions: ModelOption[]) {
  const sources: unknown[] = [
    run.modelValue,
    run.task?.requestedModel,
    taskRequestModel(run.task),
    run.task?.model,
    run.task?.resolvedModel,
  ];
  const resolved = sources
    .map((source) => resolveModelOptionValue(source, modelOptions))
    .find((value) => Boolean(value));
  return resolved ?? '';
}
// Map an arbitrary value to a model option's `value`, matching either the
// option value itself or any of its models' alias/name/display name.
// Returns '' when the value is empty or nothing matches.
function resolveModelOptionValue(value: unknown, modelOptions: ModelOption[]) {
  const text = stringFromUnknown(value);
  if (!text) {
    return '';
  }
  const exact = modelOptions.find((option) => option.value === text);
  if (exact) {
    return exact.value;
  }
  for (const option of modelOptions) {
    const hasAlias = option.models.some((model) => (
      [model.modelAlias, model.modelName, model.displayName].includes(text)
    ));
    if (hasAlias) {
      return option.value;
    }
  }
  return '';
}
// Pull the originally requested model name out of the task's raw request
// payload, coerced to a string ('' when absent).
function taskRequestModel(task: GatewayTask | undefined) {
  const requested = task?.request?.model;
  return stringFromUnknown(requested);
}
// Return the first argument that coerces to a non-empty string, or ''.
function firstString(...values: unknown[]) {
  const hit = values
    .map((value) => stringFromUnknown(value))
    .find((text) => Boolean(text));
  return hit ?? '';
}
// Return a new array in which only the run matching `localId` has the
// patch's fields merged in; all other runs are passed through unchanged.
function updateMediaRun(runs: MediaGenerationRun[], localId: string, patch: Partial<MediaGenerationRun>) {
  return runs.map((run) => {
    if (run.localId !== localId) {
      return run;
    }
    return { ...run, ...patch };
  });
}
@ -1091,7 +1159,8 @@ function mediaRunFromStorage(value: unknown, index: number): MediaGenerationRun
const task = taskFromStorage(record.task);
const createdAt = dateStringFromUnknown(record.createdAt) ?? new Date().toISOString();
const localId = stringFromUnknown(record.localId) || task?.id || `stored-${index}-${createdAt}`;
const modelLabel = stringFromUnknown(record.modelLabel) || task?.model || '未知模型';
const modelValue = stringFromUnknown(record.modelValue) || task?.requestedModel || taskRequestModel(task) || task?.model || '';
const modelLabel = stringFromUnknown(record.modelLabel) || modelValue || '未知模型';
let status: MediaGenerationRun['status'] = stringFromUnknown(record.status) || task?.status || 'failed';
let error = stringFromUnknown(record.error);
if (status === 'submitting' && !task?.id) {
@ -1104,6 +1173,7 @@ function mediaRunFromStorage(value: unknown, index: number): MediaGenerationRun
localId,
mode,
modelLabel,
modelValue,
prompt,
settings: mediaSettingsFromStorage(record.settings),
status,

View File

@ -41,6 +41,7 @@ export interface MediaGenerationRun {
localId: string;
mode: Exclude<PlaygroundMode, 'chat'>;
modelLabel: string;
modelValue?: string;
prompt: string;
settings: MediaGenerationSettings;
status: GatewayTask['status'] | 'submitting';