From 00379b4acf8c65822812b4ad5c7dd68ec6d3043d Mon Sep 17 00:00:00 2001
From: Tara Ding <38710454+windtara0619@users.noreply.github.com>
Date: Sun, 26 Apr 2026 19:41:55 -0700
Subject: [PATCH] Move benchmark serving client into benchmarks folder

---
 benchmarks/benchmark_comfyui_serving.py | 374 ++++++++++++++++++++++++
 comfy/cli_args.py                       |   1 +
 main.py                                 |  37 ++-
 server.py                               |  78 ++++-
 4 files changed, 475 insertions(+), 15 deletions(-)
 create mode 100644 benchmarks/benchmark_comfyui_serving.py

diff --git a/benchmarks/benchmark_comfyui_serving.py b/benchmarks/benchmark_comfyui_serving.py
new file mode 100644
index 000000000..07d498c21
--- /dev/null
+++ b/benchmarks/benchmark_comfyui_serving.py
@@ -0,0 +1,374 @@
+#!/usr/bin/env python3
+"""
+Simple serving benchmark client for ComfyUI's HTTP API.
+
+This script is inspired by diffusion serving benchmarks and is designed to:
+  - submit prompts to ComfyUI (/prompt or /bench/prompt),
+  - optionally shape request arrivals (fixed rate or Poisson),
+  - poll completion via /history/{prompt_id},
+  - report latency/throughput/error metrics.
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import math
+import random
+import statistics
+import time
+import uuid
+from dataclasses import dataclass, asdict
+from pathlib import Path
+from typing import Any
+
+import aiohttp
+
+
+@dataclass
+class RequestResult:
+    request_index: int
+    prompt_id: str | None
+    ok: bool
+    error: str | None
+    queued_at: float
+    started_at: float
+    finished_at: float
+    end_to_end_s: float
+    queue_wait_ms: float | None
+    execution_ms: float | None
+
+
+def percentile(values: list[float], pct: float) -> float:
+    if not values:
+        return float("nan")
+    if len(values) == 1:
+        return values[0]
+    values = sorted(values)
+    rank = (len(values) - 1) * (pct / 100.0)
+    lower = math.floor(rank)
+    upper = math.ceil(rank)
+    if lower == upper:
+        return values[lower]
+    weight = rank - lower
+    return values[lower] * (1.0 - weight) + values[upper] * weight
+
+
+def patch_seed_in_prompt(prompt: dict[str, Any], seed: int, seed_path: str | None) -> dict[str, Any]:
+    """
+    Patch prompt seed in-place for common sampler nodes.
+    seed_path format: "<node_id>.<input_name>".
+ """ + if seed_path: + try: + node_id, input_name = seed_path.split(".", 1) + prompt[node_id]["inputs"][input_name] = seed + return prompt + except Exception as exc: + raise ValueError(f"Invalid --seed-path '{seed_path}': {exc}") from exc + + # Best-effort fallback: update any input key named 'seed' or 'noise_seed' + for node in prompt.values(): + if not isinstance(node, dict): + continue + inputs = node.get("inputs") + if not isinstance(inputs, dict): + continue + if "seed" in inputs: + inputs["seed"] = seed + if "noise_seed" in inputs: + inputs["noise_seed"] = seed + return prompt + + +def load_prompt_template(path: Path) -> dict[str, Any]: + data = json.loads(path.read_text()) + if "prompt" in data and isinstance(data["prompt"], dict): + return data + if isinstance(data, dict): + return {"prompt": data} + raise ValueError("Prompt file must be a JSON object (prompt graph or wrapper with 'prompt').") + + +async def submit_prompt( + session: aiohttp.ClientSession, + base_url: str, + endpoint: str, + payload: dict[str, Any], + timeout_s: float, +) -> str: + url = f"{base_url}{endpoint}" + async with session.post(url, json=payload, timeout=timeout_s) as resp: + text = await resp.text() + if resp.status != 200: + raise RuntimeError(f"submit failed [{resp.status}] {text}") + body = json.loads(text) + prompt_id = body.get("prompt_id") + if not prompt_id: + raise RuntimeError(f"missing prompt_id in response: {body}") + return prompt_id + + +async def wait_for_prompt_done( + session: aiohttp.ClientSession, + base_url: str, + prompt_id: str, + poll_interval_s: float, + timeout_s: float, +) -> tuple[float | None, float | None]: + """ + Returns (queue_wait_ms, execution_ms) when available from history status messages. + Falls back to (None, None) if unavailable. + """ + deadline = time.perf_counter() + timeout_s + history_url = f"{base_url}/history/{prompt_id}" + + while time.perf_counter() < deadline: + async with session.get(history_url, timeout=timeout_s) as resp: + if resp.status != 200: + text = await resp.text() + raise RuntimeError(f"history failed [{resp.status}] {text}") + + payload = await resp.json() + if not payload: + await asyncio.sleep(poll_interval_s) + continue + + history_item = payload.get(prompt_id) + if history_item is None: + await asyncio.sleep(poll_interval_s) + continue + + status = history_item.get("status", {}) + status_str = status.get("status_str") + messages = status.get("messages", []) + if status_str not in ("success", "error"): + await asyncio.sleep(poll_interval_s) + continue + + queue_wait_ms = None + execution_ms = None + try: + timestamp_map: dict[str, int] = {} + for event, msg in messages: + if isinstance(msg, dict) and "timestamp" in msg: + timestamp_map[event] = int(msg["timestamp"]) + start_ts = timestamp_map.get("execution_start") + end_ts = timestamp_map.get("execution_success") or timestamp_map.get("execution_error") + if start_ts is not None and end_ts is not None: + execution_ms = max(0.0, end_ts - start_ts) + except Exception: + execution_ms = None + + return queue_wait_ms, execution_ms + + await asyncio.sleep(poll_interval_s) + + raise TimeoutError(f"timed out waiting for prompt_id={prompt_id}") + + +def build_arrival_schedule(num_requests: int, request_rate: float, poisson: bool, seed: int) -> list[float]: + """ + Returns absolute offsets (seconds from benchmark start) for each request. 
+ """ + if request_rate <= 0: + return [0.0] * num_requests + + rnd = random.Random(seed) + offsets: list[float] = [] + t = 0.0 + for _ in range(num_requests): + if poisson: + delta = rnd.expovariate(request_rate) + else: + delta = 1.0 / request_rate + t += delta + offsets.append(t) + return offsets + + +async def run_request( + idx: int, + start_time: float, + scheduled_offset_s: float, + semaphore: asyncio.Semaphore, + session: aiohttp.ClientSession, + args: argparse.Namespace, + prompt_wrapper_template: dict[str, Any], +) -> RequestResult: + await asyncio.sleep(max(0.0, (start_time + scheduled_offset_s) - time.perf_counter())) + queued_at = time.perf_counter() + + async with semaphore: + started_at = time.perf_counter() + prompt_id = None + try: + payload = json.loads(json.dumps(prompt_wrapper_template)) + payload.setdefault("extra_data", {}) + payload["client_id"] = args.client_id + + seed = args.base_seed + idx + payload["prompt"] = patch_seed_in_prompt(payload["prompt"], seed, args.seed_path) + + prompt_id = await submit_prompt( + session=session, + base_url=args.host, + endpoint=args.endpoint, + payload=payload, + timeout_s=args.request_timeout_s, + ) + + queue_wait_ms, execution_ms = await wait_for_prompt_done( + session=session, + base_url=args.host, + prompt_id=prompt_id, + poll_interval_s=args.poll_interval_s, + timeout_s=args.request_timeout_s, + ) + finished_at = time.perf_counter() + return RequestResult( + request_index=idx, + prompt_id=prompt_id, + ok=True, + error=None, + queued_at=queued_at, + started_at=started_at, + finished_at=finished_at, + end_to_end_s=finished_at - queued_at, + queue_wait_ms=queue_wait_ms, + execution_ms=execution_ms, + ) + except Exception as exc: + finished_at = time.perf_counter() + return RequestResult( + request_index=idx, + prompt_id=prompt_id, + ok=False, + error=repr(exc), + queued_at=queued_at, + started_at=started_at, + finished_at=finished_at, + end_to_end_s=finished_at - queued_at, + queue_wait_ms=None, + execution_ms=None, + ) + + +def print_summary(results: list[RequestResult], wall_s: float) -> None: + success = [r for r in results if r.ok] + fail = [r for r in results if not r.ok] + lat_s = [r.end_to_end_s for r in success] + queue_wait_ms = [r.queue_wait_ms for r in success if r.queue_wait_ms is not None] + exec_ms = [r.execution_ms for r in success if r.execution_ms is not None] + + throughput = (len(success) / wall_s) if wall_s > 0 else 0.0 + print("\n=== ComfyUI Serving Benchmark Summary ===") + print(f"requests_total: {len(results)}") + print(f"requests_success: {len(success)}") + print(f"requests_failed: {len(fail)}") + print(f"wall_time_s: {wall_s:.3f}") + print(f"throughput_req_s: {throughput:.3f}") + + if lat_s: + print(f"latency_p50_s: {percentile(lat_s, 50):.3f}") + print(f"latency_p90_s: {percentile(lat_s, 90):.3f}") + print(f"latency_p95_s: {percentile(lat_s, 95):.3f}") + print(f"latency_p99_s: {percentile(lat_s, 99):.3f}") + print(f"latency_mean_s: {statistics.mean(lat_s):.3f}") + print(f"latency_max_s: {max(lat_s):.3f}") + + if queue_wait_ms: + print(f"queue_wait_mean_ms: {statistics.mean(queue_wait_ms):.2f}") + print(f"queue_wait_p95_ms: {percentile(queue_wait_ms, 95):.2f}") + + if exec_ms: + print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}") + print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}") + + if fail: + print("\nSample failures:") + for r in fail[:5]: + print(f" idx={r.request_index} prompt_id={r.prompt_id} error={r.error}") + + +def parse_args() -> argparse.Namespace: + p = 
argparse.ArgumentParser(description="Benchmark ComfyUI request serving.")
+    p.add_argument("--host", type=str, default="http://127.0.0.1:8188", help="ComfyUI base URL.")
+    p.add_argument(
+        "--endpoint",
+        type=str,
+        default="/prompt",
+        choices=("/prompt", "/bench/prompt"),
+        help="Submission endpoint.",
+    )
+    p.add_argument("--prompt-file", type=Path, required=True, help="Path to prompt JSON.")
+    p.add_argument("--num-requests", type=int, default=50)
+    p.add_argument("--max-concurrency", type=int, default=8)
+    p.add_argument("--request-rate", type=float, default=0.0, help="Requests/sec. 0 = fire immediately.")
+    p.add_argument("--poisson", action="store_true", help="Use Poisson inter-arrival when request-rate > 0.")
+    p.add_argument("--base-seed", type=int, default=1234)
+    p.add_argument(
+        "--seed-path",
+        type=str,
+        default=None,
+        help="Optional path to seed field in prompt: <node_id>.<input_name> (e.g. 3.seed).",
+    )
+    p.add_argument("--client-id", type=str, default=f"bench-{uuid.uuid4().hex[:12]}")
+    p.add_argument("--request-timeout-s", type=float, default=600.0)
+    p.add_argument("--poll-interval-s", type=float, default=0.2)
+    p.add_argument("--output-json", type=Path, default=None, help="Write detailed result JSON.")
+    p.add_argument("--seed", type=int, default=0, help="RNG seed for schedule generation.")
+    return p.parse_args()
+
+
+async def async_main(args: argparse.Namespace) -> None:
+    prompt_template = load_prompt_template(args.prompt_file)
+    schedule = build_arrival_schedule(
+        num_requests=args.num_requests,
+        request_rate=args.request_rate,
+        poisson=args.poisson,
+        seed=args.seed,
+    )
+    semaphore = asyncio.Semaphore(args.max_concurrency)
+    connector = aiohttp.TCPConnector(limit=max(args.max_concurrency * 2, 32))
+
+    started = time.perf_counter()
+    async with aiohttp.ClientSession(connector=connector) as session:
+        tasks = [
+            asyncio.create_task(
+                run_request(
+                    idx=i,
+                    start_time=started,
+                    scheduled_offset_s=schedule[i],
+                    semaphore=semaphore,
+                    session=session,
+                    args=args,
+                    prompt_wrapper_template=prompt_template,
+                )
+            )
+            for i in range(args.num_requests)
+        ]
+        results = await asyncio.gather(*tasks)
+    wall_s = time.perf_counter() - started
+
+    print_summary(results, wall_s)
+
+    if args.output_json is not None:
+        out = {
+            "config": vars(args),
+            "wall_time_s": wall_s,
+            "results": [asdict(r) for r in sorted(results, key=lambda x: x.request_index)],
+        }
+        args.output_json.parent.mkdir(parents=True, exist_ok=True)
+        args.output_json.write_text(json.dumps(out, indent=2, default=str))
+        print(f"\nWrote results to: {args.output_json}")
+
+
+def main() -> None:
+    args = parse_args()
+    asyncio.run(async_main(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dbaadf723..e9828a5db 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -225,6 +225,7 @@ parser.add_argument(
 parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. 
Overrides --base-directory.") parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.") +parser.add_argument("--benchmark-server-only", action="store_true", help="Enable lightweight benchmark routes and worker fast-paths focused on model serving throughput/latency.") parser.add_argument( "--comfy-api-base", diff --git a/main.py b/main.py index dbaf2745c..5013bac42 100644 --- a/main.py +++ b/main.py @@ -293,6 +293,7 @@ def prompt_worker(q, server_instance): gc_collect_interval = 10.0 while True: + benchmark_mode = args.benchmark_server_only timeout = 1000.0 if need_gc: timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) @@ -301,6 +302,7 @@ def prompt_worker(q, server_instance): if queue_item is not None: item, item_id = queue_item execution_start_time = time.perf_counter() + execution_start_wall_ms = int(time.time() * 1000) prompt_id = item[1] server_instance.last_prompt_id = prompt_id @@ -308,15 +310,21 @@ def prompt_worker(q, server_instance): extra_data = item[3].copy() for k in sensitive: extra_data[k] = sensitive[k] + benchmark_mode = args.benchmark_server_only or extra_data.get("benchmark_server_only", False) - asset_seeder.pause() + if not benchmark_mode: + asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] + history_result = e.history_result + if benchmark_mode: + history_result = {"outputs": {}, "meta": {}} + q.task_done(item_id, - e.history_result, + history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, @@ -325,16 +333,24 @@ def prompt_worker(q, server_instance): server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) current_time = time.perf_counter() - execution_time = current_time - execution_start_time + execution_time_s = current_time - execution_start_time # Log Time in a more readable way after 10 minutes - if execution_time > 600: - execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time)) - logging.info(f"Prompt executed in {execution_time}") + if execution_time_s > 600: + execution_time_formatted = time.strftime("%H:%M:%S", time.gmtime(execution_time_s)) + logging.info(f"Prompt executed in {execution_time_formatted}") else: - logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) + logging.info("Prompt executed in {:.2f} seconds".format(execution_time_s)) - if not asset_seeder.is_disabled(): + queue_wait_ms = 0.0 + created_at = extra_data.get("create_time") + if isinstance(created_at, int): + queue_wait_ms = max(0.0, execution_start_wall_ms - created_at) + + if benchmark_mode: + server_instance.record_benchmark_result(prompt_id, e.success, execution_time_s * 1000.0, queue_wait_ms) + + if not benchmark_mode and not asset_seeder.is_disabled(): paths = _collect_output_absolute_paths(e.history_result) register_output_files(paths, job_id=prompt_id) @@ -360,9 +376,10 @@ def prompt_worker(q, server_instance): need_gc = False hook_breaker_ac10a0.restore_functions() - if not asset_seeder.is_disabled(): + if not benchmark_mode and not asset_seeder.is_disabled(): asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True) - asset_seeder.resume() + if not benchmark_mode: + asset_seeder.resume() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): diff --git a/server.py b/server.py index 881da8e66..5db448b7f 
100644 --- a/server.py +++ b/server.py @@ -16,6 +16,7 @@ import struct import ssl import socket import ipaddress +import threading from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from io import BytesIO @@ -252,6 +253,17 @@ class PromptServer(): self.client_id = None self.on_prompt_handlers = [] + self._benchmark_lock = threading.Lock() + self._benchmark_stats = { + "requests_total": 0, + "requests_success": 0, + "requests_error": 0, + "latency_ms_total": 0.0, + "latency_ms_max": 0.0, + "queue_wait_ms_total": 0.0, + "queue_wait_ms_max": 0.0, + "last_prompt_id": None, + } @routes.get('/ws') async def websocket_handler(request): @@ -912,12 +924,17 @@ class PromptServer(): queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1]) return web.json_response(queue_info) - @routes.post("/prompt") - async def post_prompt(request): - logging.info("got prompt") - json_data = await request.json() - json_data = self.trigger_on_prompt(json_data) + @routes.get("/bench/stats") + async def get_bench_stats(request): + stats = self.get_benchmark_stats() + return web.json_response(stats) + @routes.post("/bench/reset") + async def reset_bench_stats(request): + self.reset_benchmark_stats() + return web.json_response({"status": "ok"}) + + async def enqueue_prompt(json_data): if "number" in json_data: number = float(json_data['number']) else: @@ -967,6 +984,22 @@ class PromptServer(): } return web.json_response({"error": error, "node_errors": {}}, status=400) + @routes.post("/bench/prompt") + async def post_bench_prompt(request): + json_data = await request.json() + json_data = self.trigger_on_prompt(json_data) + extra_data = json_data.setdefault("extra_data", {}) + extra_data["benchmark_server_only"] = True + extra_data.setdefault("preview_method", "none") + return await enqueue_prompt(json_data) + + @routes.post("/prompt") + async def post_prompt(request): + logging.info("got prompt") + json_data = await request.json() + json_data = self.trigger_on_prompt(json_data) + return await enqueue_prompt(json_data) + @routes.post("/queue") async def post_queue(request): json_data = await request.json() @@ -1111,6 +1144,41 @@ class PromptServer(): prompt_info['exec_info'] = exec_info return prompt_info + def reset_benchmark_stats(self): + with self._benchmark_lock: + self._benchmark_stats = { + "requests_total": 0, + "requests_success": 0, + "requests_error": 0, + "latency_ms_total": 0.0, + "latency_ms_max": 0.0, + "queue_wait_ms_total": 0.0, + "queue_wait_ms_max": 0.0, + "last_prompt_id": None, + } + + def record_benchmark_result(self, prompt_id, success, latency_ms, queue_wait_ms=0.0): + with self._benchmark_lock: + self._benchmark_stats["requests_total"] += 1 + if success: + self._benchmark_stats["requests_success"] += 1 + else: + self._benchmark_stats["requests_error"] += 1 + self._benchmark_stats["latency_ms_total"] += max(0.0, latency_ms) + self._benchmark_stats["queue_wait_ms_total"] += max(0.0, queue_wait_ms) + self._benchmark_stats["latency_ms_max"] = max(self._benchmark_stats["latency_ms_max"], max(0.0, latency_ms)) + self._benchmark_stats["queue_wait_ms_max"] = max(self._benchmark_stats["queue_wait_ms_max"], max(0.0, queue_wait_ms)) + self._benchmark_stats["last_prompt_id"] = prompt_id + + def get_benchmark_stats(self): + with self._benchmark_lock: + stats = dict(self._benchmark_stats) + + total = stats["requests_total"] + stats["latency_ms_avg"] = (stats["latency_ms_total"] / total) if total > 0 else 0.0 + stats["queue_wait_ms_avg"] = (stats["queue_wait_ms_total"] / 
total) if total > 0 else 0.0 + return stats + async def send(self, event, data, sid=None): if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: await self.send_image(data, sid=sid)
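
For quick spot checks while the benchmark client is running, the new GET /bench/stats route can be polled from any HTTP client. The snippet below is illustrative only and not part of the patch: it assumes a ComfyUI instance listening on the client's default http://127.0.0.1:8188 and uses only the standard library; the field names mirror the _benchmark_stats counters plus the derived latency_ms_avg and queue_wait_ms_avg values computed in get_benchmark_stats().

#!/usr/bin/env python3
# Illustrative helper (not part of the patch): poll GET /bench/stats on a running
# ComfyUI instance and print the aggregate counters added by this change.
import json
import urllib.request

BASE_URL = "http://127.0.0.1:8188"  # assumed default ComfyUI address/port


def fetch_bench_stats(base_url: str = BASE_URL) -> dict:
    # Returns the raw counters plus the derived latency_ms_avg / queue_wait_ms_avg fields.
    with urllib.request.urlopen(f"{base_url}/bench/stats", timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


if __name__ == "__main__":
    stats = fetch_bench_stats()
    print(f"total={stats['requests_total']} "
          f"success={stats['requests_success']} error={stats['requests_error']}")
    print(f"latency_ms_avg={stats['latency_ms_avg']:.2f} "
          f"latency_ms_max={stats['latency_ms_max']:.2f}")
    print(f"queue_wait_ms_avg={stats['queue_wait_ms_avg']:.2f} "
          f"last_prompt_id={stats['last_prompt_id']}")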