Move benchmark serving client into benchmarks folder

Tara Ding 2026-04-26 19:41:55 -07:00
parent 96363fa74a
commit 00379b4acf
4 changed files with 475 additions and 15 deletions


@@ -0,0 +1,374 @@
#!/usr/bin/env python3
"""
Simple serving benchmark client for ComfyUI's HTTP API.
This script is inspired by diffusion serving benchmarks and is designed to:
- submit prompts to ComfyUI (/prompt or /bench/prompt),
- optionally shape request arrivals (fixed rate or Poisson),
- poll completion via /history/{prompt_id},
- report latency/throughput/error metrics.
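
Example invocation (script path assumed; the flags are defined in parse_args below):

    python benchmarks/benchmark_serving.py --prompt-file workflow_api.json \
        --num-requests 100 --max-concurrency 8 --request-rate 4.0 --poisson --output-json results.json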
"""
from __future__ import annotations
import argparse
import asyncio
import json
import math
import random
import statistics
import time
import uuid
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any
import aiohttp
@dataclass
class RequestResult:
request_index: int
prompt_id: str | None
ok: bool
error: str | None
queued_at: float
started_at: float
finished_at: float
end_to_end_s: float
queue_wait_ms: float | None
execution_ms: float | None
def percentile(values: list[float], pct: float) -> float:
if not values:
return float("nan")
if len(values) == 1:
return values[0]
values = sorted(values)
rank = (len(values) - 1) * (pct / 100.0)
lower = math.floor(rank)
upper = math.ceil(rank)
if lower == upper:
return values[lower]
weight = rank - lower
return values[lower] * (1.0 - weight) + values[upper] * weight
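
# Worked example: percentile([1, 2, 3, 4], 90) computes rank = 3 * 0.9 = 2.7 and
# interpolates values[2] = 3 with values[3] = 4 at weight 0.7, giving 3.7.
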
def patch_seed_in_prompt(prompt: dict[str, Any], seed: int, seed_path: str | None) -> dict[str, Any]:
"""
Patch prompt seed in-place for common sampler nodes.
seed_path format: "<node_id>.<input_name>".
"""
if seed_path:
try:
node_id, input_name = seed_path.split(".", 1)
prompt[node_id]["inputs"][input_name] = seed
return prompt
except Exception as exc:
raise ValueError(f"Invalid --seed-path '{seed_path}': {exc}") from exc
# Best-effort fallback: update any input key named 'seed' or 'noise_seed'
for node in prompt.values():
if not isinstance(node, dict):
continue
inputs = node.get("inputs")
if not isinstance(inputs, dict):
continue
if "seed" in inputs:
inputs["seed"] = seed
if "noise_seed" in inputs:
inputs["noise_seed"] = seed
return prompt
def load_prompt_template(path: Path) -> dict[str, Any]:
    data = json.loads(path.read_text())
    # Reject non-objects up front so a bare JSON string or list cannot slip
    # through the membership check below; accept a wrapper or a bare graph.
    if not isinstance(data, dict):
        raise ValueError("Prompt file must be a JSON object (prompt graph or wrapper with 'prompt').")
    if "prompt" in data and isinstance(data["prompt"], dict):
        return data
    return {"prompt": data}
async def submit_prompt(
session: aiohttp.ClientSession,
base_url: str,
endpoint: str,
payload: dict[str, Any],
timeout_s: float,
) -> str:
url = f"{base_url}{endpoint}"
async with session.post(url, json=payload, timeout=timeout_s) as resp:
text = await resp.text()
if resp.status != 200:
raise RuntimeError(f"submit failed [{resp.status}] {text}")
body = json.loads(text)
prompt_id = body.get("prompt_id")
if not prompt_id:
raise RuntimeError(f"missing prompt_id in response: {body}")
return prompt_id
async def wait_for_prompt_done(
session: aiohttp.ClientSession,
base_url: str,
prompt_id: str,
poll_interval_s: float,
timeout_s: float,
) -> tuple[float | None, float | None]:
"""
Returns (queue_wait_ms, execution_ms) when available from history status messages.
Falls back to (None, None) if unavailable.
"""
deadline = time.perf_counter() + timeout_s
history_url = f"{base_url}/history/{prompt_id}"
while time.perf_counter() < deadline:
async with session.get(history_url, timeout=timeout_s) as resp:
if resp.status != 200:
text = await resp.text()
raise RuntimeError(f"history failed [{resp.status}] {text}")
payload = await resp.json()
if not payload:
await asyncio.sleep(poll_interval_s)
continue
history_item = payload.get(prompt_id)
if history_item is None:
await asyncio.sleep(poll_interval_s)
continue
status = history_item.get("status", {})
status_str = status.get("status_str")
messages = status.get("messages", [])
if status_str not in ("success", "error"):
await asyncio.sleep(poll_interval_s)
continue
queue_wait_ms = None
execution_ms = None
try:
timestamp_map: dict[str, int] = {}
for event, msg in messages:
if isinstance(msg, dict) and "timestamp" in msg:
timestamp_map[event] = int(msg["timestamp"])
start_ts = timestamp_map.get("execution_start")
end_ts = timestamp_map.get("execution_success") or timestamp_map.get("execution_error")
                if start_ts is not None and end_ts is not None:
                    # Status timestamps are wall-clock milliseconds (cf.
                    # execution_start_wall_ms in main.py), so the delta is already ms.
                    execution_ms = max(0.0, float(end_ts - start_ts))
except Exception:
execution_ms = None
return queue_wait_ms, execution_ms
await asyncio.sleep(poll_interval_s)
raise TimeoutError(f"timed out waiting for prompt_id={prompt_id}")
def build_arrival_schedule(num_requests: int, request_rate: float, poisson: bool, seed: int) -> list[float]:
"""
Returns absolute offsets (seconds from benchmark start) for each request.
"""
if request_rate <= 0:
return [0.0] * num_requests
rnd = random.Random(seed)
offsets: list[float] = []
t = 0.0
for _ in range(num_requests):
if poisson:
delta = rnd.expovariate(request_rate)
else:
delta = 1.0 / request_rate
t += delta
offsets.append(t)
return offsets
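
# Worked example for build_arrival_schedule: request_rate=2.0 yields fixed offsets
# [0.5, 1.0, 1.5, ...]; with poisson=True the gaps are drawn from expovariate(2.0),
# whose mean is also 0.5 seconds, so both shapes target the same average rate.
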
async def run_request(
idx: int,
start_time: float,
scheduled_offset_s: float,
semaphore: asyncio.Semaphore,
session: aiohttp.ClientSession,
args: argparse.Namespace,
prompt_wrapper_template: dict[str, Any],
) -> RequestResult:
await asyncio.sleep(max(0.0, (start_time + scheduled_offset_s) - time.perf_counter()))
queued_at = time.perf_counter()
async with semaphore:
started_at = time.perf_counter()
prompt_id = None
try:
            payload = json.loads(json.dumps(prompt_wrapper_template))  # cheap deep copy so seed patching cannot mutate the shared template
payload.setdefault("extra_data", {})
payload["client_id"] = args.client_id
seed = args.base_seed + idx
payload["prompt"] = patch_seed_in_prompt(payload["prompt"], seed, args.seed_path)
prompt_id = await submit_prompt(
session=session,
base_url=args.host,
endpoint=args.endpoint,
payload=payload,
timeout_s=args.request_timeout_s,
)
queue_wait_ms, execution_ms = await wait_for_prompt_done(
session=session,
base_url=args.host,
prompt_id=prompt_id,
poll_interval_s=args.poll_interval_s,
timeout_s=args.request_timeout_s,
)
finished_at = time.perf_counter()
return RequestResult(
request_index=idx,
prompt_id=prompt_id,
ok=True,
error=None,
queued_at=queued_at,
started_at=started_at,
finished_at=finished_at,
end_to_end_s=finished_at - queued_at,
queue_wait_ms=queue_wait_ms,
execution_ms=execution_ms,
)
except Exception as exc:
finished_at = time.perf_counter()
return RequestResult(
request_index=idx,
prompt_id=prompt_id,
ok=False,
error=repr(exc),
queued_at=queued_at,
started_at=started_at,
finished_at=finished_at,
end_to_end_s=finished_at - queued_at,
queue_wait_ms=None,
execution_ms=None,
)
def print_summary(results: list[RequestResult], wall_s: float) -> None:
success = [r for r in results if r.ok]
fail = [r for r in results if not r.ok]
lat_s = [r.end_to_end_s for r in success]
queue_wait_ms = [r.queue_wait_ms for r in success if r.queue_wait_ms is not None]
exec_ms = [r.execution_ms for r in success if r.execution_ms is not None]
throughput = (len(success) / wall_s) if wall_s > 0 else 0.0
print("\n=== ComfyUI Serving Benchmark Summary ===")
print(f"requests_total: {len(results)}")
print(f"requests_success: {len(success)}")
print(f"requests_failed: {len(fail)}")
print(f"wall_time_s: {wall_s:.3f}")
print(f"throughput_req_s: {throughput:.3f}")
if lat_s:
print(f"latency_p50_s: {percentile(lat_s, 50):.3f}")
print(f"latency_p90_s: {percentile(lat_s, 90):.3f}")
print(f"latency_p95_s: {percentile(lat_s, 95):.3f}")
print(f"latency_p99_s: {percentile(lat_s, 99):.3f}")
print(f"latency_mean_s: {statistics.mean(lat_s):.3f}")
print(f"latency_max_s: {max(lat_s):.3f}")
if queue_wait_ms:
print(f"queue_wait_mean_ms: {statistics.mean(queue_wait_ms):.2f}")
print(f"queue_wait_p95_ms: {percentile(queue_wait_ms, 95):.2f}")
if exec_ms:
print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}")
print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}")
if fail:
print("\nSample failures:")
for r in fail[:5]:
print(f" idx={r.request_index} prompt_id={r.prompt_id} error={r.error}")
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Benchmark ComfyUI request serving.")
p.add_argument("--host", type=str, default="http://127.0.0.1:8188", help="ComfyUI base URL.")
p.add_argument(
"--endpoint",
type=str,
default="/prompt",
choices=("/prompt", "/bench/prompt"),
help="Submission endpoint.",
)
p.add_argument("--prompt-file", type=Path, required=True, help="Path to prompt JSON.")
p.add_argument("--num-requests", type=int, default=50)
p.add_argument("--max-concurrency", type=int, default=8)
p.add_argument("--request-rate", type=float, default=0.0, help="Requests/sec. 0 = fire immediately.")
p.add_argument("--poisson", action="store_true", help="Use Poisson inter-arrival when request-rate > 0.")
p.add_argument("--base-seed", type=int, default=1234)
p.add_argument(
"--seed-path",
type=str,
default=None,
help="Optional path to seed field in prompt: <node_id>.<input_name> (e.g. 3.seed).",
)
p.add_argument("--client-id", type=str, default=f"bench-{uuid.uuid4().hex[:12]}")
p.add_argument("--request-timeout-s", type=float, default=600.0)
p.add_argument("--poll-interval-s", type=float, default=0.2)
p.add_argument("--output-json", type=Path, default=None, help="Write detailed result JSON.")
p.add_argument("--seed", type=int, default=0, help="RNG seed for schedule generation.")
return p.parse_args()
async def async_main(args: argparse.Namespace) -> None:
prompt_template = load_prompt_template(args.prompt_file)
schedule = build_arrival_schedule(
num_requests=args.num_requests,
request_rate=args.request_rate,
poisson=args.poisson,
seed=args.seed,
)
semaphore = asyncio.Semaphore(args.max_concurrency)
connector = aiohttp.TCPConnector(limit=max(args.max_concurrency * 2, 32))
started = time.perf_counter()
async with aiohttp.ClientSession(connector=connector) as session:
tasks = [
asyncio.create_task(
run_request(
idx=i,
start_time=started,
scheduled_offset_s=schedule[i],
semaphore=semaphore,
session=session,
args=args,
prompt_wrapper_template=prompt_template,
)
)
for i in range(args.num_requests)
]
results = await asyncio.gather(*tasks)
wall_s = time.perf_counter() - started
print_summary(results, wall_s)
if args.output_json is not None:
out = {
"config": vars(args),
"wall_time_s": wall_s,
"results": [asdict(r) for r in sorted(results, key=lambda x: x.request_index)],
}
args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(json.dumps(out, indent=2, default=str))  # default=str serializes the Path values in config
print(f"\nWrote results to: {args.output_json}")
def main() -> None:
args = parse_args()
asyncio.run(async_main(args))
if __name__ == "__main__":
main()


@@ -225,6 +225,7 @@ parser.add_argument(
 parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
 parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
+parser.add_argument("--benchmark-server-only", action="store_true", help="Enable lightweight benchmark routes and worker fast-paths focused on model serving throughput/latency.")
 parser.add_argument(
     "--comfy-api-base",

main.py

@@ -293,6 +293,7 @@ def prompt_worker(q, server_instance):
     gc_collect_interval = 10.0
     while True:
+        benchmark_mode = args.benchmark_server_only
         timeout = 1000.0
         if need_gc:
             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@@ -301,6 +302,7 @@ def prompt_worker(q, server_instance):
         if queue_item is not None:
             item, item_id = queue_item
             execution_start_time = time.perf_counter()
+            execution_start_wall_ms = int(time.time() * 1000)
             prompt_id = item[1]
             server_instance.last_prompt_id = prompt_id
@@ -308,15 +310,21 @@ def prompt_worker(q, server_instance):
             extra_data = item[3].copy()
             for k in sensitive:
                 extra_data[k] = sensitive[k]
+            benchmark_mode = args.benchmark_server_only or extra_data.get("benchmark_server_only", False)
 
-            asset_seeder.pause()
+            if not benchmark_mode:
+                asset_seeder.pause()
             e.execute(item[2], prompt_id, extra_data, item[4])
             need_gc = True
             remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
+            history_result = e.history_result
+            if benchmark_mode:
+                history_result = {"outputs": {}, "meta": {}}
             q.task_done(item_id,
-                        e.history_result,
+                        history_result,
                         status=execution.PromptQueue.ExecutionStatus(
                             status_str='success' if e.success else 'error',
                             completed=e.success,
@@ -325,16 +333,24 @@ def prompt_worker(q, server_instance):
             server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
 
             current_time = time.perf_counter()
-            execution_time = current_time - execution_start_time
+            execution_time_s = current_time - execution_start_time
             # Log Time in a more readable way after 10 minutes
-            if execution_time > 600:
-                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
-                logging.info(f"Prompt executed in {execution_time}")
+            if execution_time_s > 600:
+                execution_time_formatted = time.strftime("%H:%M:%S", time.gmtime(execution_time_s))
+                logging.info(f"Prompt executed in {execution_time_formatted}")
             else:
-                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
+                logging.info("Prompt executed in {:.2f} seconds".format(execution_time_s))
 
-            if not asset_seeder.is_disabled():
+            queue_wait_ms = 0.0
+            created_at = extra_data.get("create_time")
+            if isinstance(created_at, int):
+                queue_wait_ms = max(0.0, execution_start_wall_ms - created_at)
+            if benchmark_mode:
+                server_instance.record_benchmark_result(prompt_id, e.success, execution_time_s * 1000.0, queue_wait_ms)
+
+            if not benchmark_mode and not asset_seeder.is_disabled():
                 paths = _collect_output_absolute_paths(e.history_result)
                 register_output_files(paths, job_id=prompt_id)
@@ -360,9 +376,10 @@ def prompt_worker(q, server_instance):
                 need_gc = False
 
         hook_breaker_ac10a0.restore_functions()
-        if not asset_seeder.is_disabled():
+        if not benchmark_mode and not asset_seeder.is_disabled():
             asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
-        asset_seeder.resume()
+        if not benchmark_mode:
+            asset_seeder.resume()
 
 async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
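
Note: queue_wait_ms above is only non-zero when the submitting client stamps
extra_data with a wall-clock millisecond "create_time" before enqueueing; the
benchmark client in this commit does not set it. A minimal sketch of what a
caller might add (the "create_time" key comes from the diff above; the helper
name is hypothetical):

    import time

    def stamp_create_time(payload: dict) -> dict:
        # Epoch milliseconds, the same clock as execution_start_wall_ms in
        # prompt_worker, so queue_wait_ms = execution_start_wall_ms - create_time.
        payload.setdefault("extra_data", {})["create_time"] = int(time.time() * 1000)
        return payload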


@@ -16,6 +16,7 @@ import struct
 import ssl
 import socket
 import ipaddress
+import threading
 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 from io import BytesIO
@@ -252,6 +253,17 @@ class PromptServer():
         self.client_id = None
         self.on_prompt_handlers = []
+        self._benchmark_lock = threading.Lock()
+        self._benchmark_stats = {
+            "requests_total": 0,
+            "requests_success": 0,
+            "requests_error": 0,
+            "latency_ms_total": 0.0,
+            "latency_ms_max": 0.0,
+            "queue_wait_ms_total": 0.0,
+            "queue_wait_ms_max": 0.0,
+            "last_prompt_id": None,
+        }
 
         @routes.get('/ws')
         async def websocket_handler(request):
@@ -912,12 +924,17 @@
             queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
             return web.json_response(queue_info)
 
-        @routes.post("/prompt")
-        async def post_prompt(request):
-            logging.info("got prompt")
-            json_data = await request.json()
-            json_data = self.trigger_on_prompt(json_data)
+        @routes.get("/bench/stats")
+        async def get_bench_stats(request):
+            stats = self.get_benchmark_stats()
+            return web.json_response(stats)
+
+        @routes.post("/bench/reset")
+        async def reset_bench_stats(request):
+            self.reset_benchmark_stats()
+            return web.json_response({"status": "ok"})
+
+        async def enqueue_prompt(json_data):
             if "number" in json_data:
                 number = float(json_data['number'])
             else:
@@ -967,6 +984,22 @@
                 }
                 return web.json_response({"error": error, "node_errors": {}}, status=400)
 
+        @routes.post("/bench/prompt")
+        async def post_bench_prompt(request):
+            json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)
+            extra_data = json_data.setdefault("extra_data", {})
+            extra_data["benchmark_server_only"] = True
+            extra_data.setdefault("preview_method", "none")
+            return await enqueue_prompt(json_data)
+
+        @routes.post("/prompt")
+        async def post_prompt(request):
+            logging.info("got prompt")
+            json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)
+            return await enqueue_prompt(json_data)
+
         @routes.post("/queue")
         async def post_queue(request):
             json_data = await request.json()
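
Since prompt_worker also honors extra_data["benchmark_server_only"] (see the
main.py hunk above), posting to /prompt with the flag set manually should
behave like posting to /bench/prompt. A minimal sketch reusing the benchmark
client's submit_prompt helper inside an open aiohttp session ("graph" is an
assumed API-format prompt dict):

    # Route A: the server injects the benchmark flags itself.
    await submit_prompt(session, base_url, "/bench/prompt", {"prompt": graph}, timeout_s=600.0)

    # Route B: opt in explicitly on the shared endpoint.
    payload = {"prompt": graph, "extra_data": {"benchmark_server_only": True, "preview_method": "none"}}
    await submit_prompt(session, base_url, "/prompt", payload, timeout_s=600.0)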
@@ -1111,6 +1144,41 @@
             prompt_info['exec_info'] = exec_info
             return prompt_info
 
+    def reset_benchmark_stats(self):
+        with self._benchmark_lock:
+            self._benchmark_stats = {
+                "requests_total": 0,
+                "requests_success": 0,
+                "requests_error": 0,
+                "latency_ms_total": 0.0,
+                "latency_ms_max": 0.0,
+                "queue_wait_ms_total": 0.0,
+                "queue_wait_ms_max": 0.0,
+                "last_prompt_id": None,
+            }
+
+    def record_benchmark_result(self, prompt_id, success, latency_ms, queue_wait_ms=0.0):
+        with self._benchmark_lock:
+            self._benchmark_stats["requests_total"] += 1
+            if success:
+                self._benchmark_stats["requests_success"] += 1
+            else:
+                self._benchmark_stats["requests_error"] += 1
+            self._benchmark_stats["latency_ms_total"] += max(0.0, latency_ms)
+            self._benchmark_stats["queue_wait_ms_total"] += max(0.0, queue_wait_ms)
+            self._benchmark_stats["latency_ms_max"] = max(self._benchmark_stats["latency_ms_max"], max(0.0, latency_ms))
+            self._benchmark_stats["queue_wait_ms_max"] = max(self._benchmark_stats["queue_wait_ms_max"], max(0.0, queue_wait_ms))
+            self._benchmark_stats["last_prompt_id"] = prompt_id
+
+    def get_benchmark_stats(self):
+        with self._benchmark_lock:
+            stats = dict(self._benchmark_stats)
+            total = stats["requests_total"]
+            stats["latency_ms_avg"] = (stats["latency_ms_total"] / total) if total > 0 else 0.0
+            stats["queue_wait_ms_avg"] = (stats["queue_wait_ms_total"] / total) if total > 0 else 0.0
+        return stats
+
     async def send(self, event, data, sid=None):
         if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
             await self.send_image(data, sid=sid)
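
The counters accumulated by record_benchmark_result can be scraped out-of-band
while a run is in flight; a minimal sketch with aiohttp (helper name and local
base URL assumed):

    import asyncio
    import aiohttp

    async def snapshot_bench_stats(base_url: str = "http://127.0.0.1:8188") -> dict:
        async with aiohttp.ClientSession() as session:
            # get_benchmark_stats() derives latency_ms_avg / queue_wait_ms_avg on read.
            async with session.get(f"{base_url}/bench/stats") as resp:
                return await resp.json()

    stats = asyncio.run(snapshot_bench_stats())
    print(stats["requests_total"], stats["latency_ms_avg"], stats["queue_wait_ms_max"])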