Move benchmark summary collection to benchmark_comfyui_serving

Tara Ding 2026-04-27 16:50:40 -07:00
parent 512deb3cd6
commit ca56e224a0
3 changed files with 16 additions and 95 deletions


@@ -396,7 +396,7 @@ async def wait_for_prompt_done(
     timeout_s: float,
 ) -> tuple[float | None, float | None]:
     """
-    Returns (queue_wait_ms, execution_ms) when available from history status messages.
+    Returns (queue_wait_ms, execution_ms) from history_item["benchmark"] written by the server.
     Falls back to (None, None) if unavailable.
     """
     deadline = time.perf_counter() + timeout_s
@@ -419,26 +419,13 @@ async def wait_for_prompt_done(
             continue
         status = history_item.get("status", {})
-        status_str = status.get("status_str")
-        messages = status.get("messages", [])
-        if status_str not in ("success", "error"):
+        if status.get("status_str") not in ("success", "error"):
             await asyncio.sleep(poll_interval_s)
             continue
-        queue_wait_ms = None
-        execution_ms = None
-        try:
-            timestamp_map: dict[str, int] = {}
-            for event, msg in messages:
-                if isinstance(msg, dict) and "timestamp" in msg:
-                    timestamp_map[event] = int(msg["timestamp"])
-            start_ts = timestamp_map.get("execution_start")
-            end_ts = timestamp_map.get("execution_success") or timestamp_map.get("execution_error")
-            if start_ts is not None and end_ts is not None:
-                execution_ms = max(0.0, end_ts - start_ts)
-        except Exception:
-            execution_ms = None
+        benchmark = history_item.get("benchmark", {})
+        queue_wait_ms = benchmark.get("queue_wait_ms")
+        execution_ms = benchmark.get("execution_ms")
         return queue_wait_ms, execution_ms
     await asyncio.sleep(poll_interval_s)
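
A minimal usage sketch of the simplified contract: only the returned (queue_wait_ms, execution_ms) pair comes from the hunk above; submit_prompt and the exact wait_for_prompt_done call signature are assumptions.

# Hedged sketch: one way the benchmark loop might consume the new return
# values. submit_prompt is a hypothetical helper, and the
# (session, prompt_id, timeout_s) signature is assumed.
async def run_one(session, workflow, timeout_s=120.0):
    prompt_id = await submit_prompt(session, workflow)  # hypothetical helper
    queue_wait_ms, execution_ms = await wait_for_prompt_done(
        session, prompt_id, timeout_s=timeout_s
    )
    if execution_ms is None:
        # Server did not write history_item["benchmark"]; skip this sample.
        return None
    return {"queue_wait_ms": queue_wait_ms, "execution_ms": execution_ms}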

main.py

@@ -293,7 +293,6 @@ def prompt_worker(q, server_instance):
     gc_collect_interval = 10.0

     while True:
-        benchmark_mode = args.benchmark_server_only
         timeout = 1000.0
         if need_gc:
             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@@ -310,18 +309,28 @@ def prompt_worker(q, server_instance):
             extra_data = item[3].copy()
             for k in sensitive:
                 extra_data[k] = sensitive[k]
-            benchmark_mode = args.benchmark_server_only or extra_data.get("benchmark_server_only", False)
+            benchmark_mode = args.benchmark_server_only
             if not benchmark_mode:
                 asset_seeder.pause()
             e.execute(item[2], prompt_id, extra_data, item[4])
+            execution_time_s = time.perf_counter() - execution_start_time
             need_gc = True
+
+            queue_wait_ms = 0.0
+            created_at = extra_data.get("create_time")
+            if isinstance(created_at, int):
+                queue_wait_ms = max(0.0, execution_start_wall_ms - created_at)
+
             remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
             history_result = e.history_result
             if benchmark_mode:
                 history_result = {"outputs": {}, "meta": {}}
+            history_result["benchmark"] = {
+                "execution_ms": execution_time_s * 1000.0,
+                "queue_wait_ms": queue_wait_ms,
+            }
             q.task_done(item_id,
                         history_result,
@@ -333,7 +342,6 @@ def prompt_worker(q, server_instance):
                 server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

             current_time = time.perf_counter()
-            execution_time_s = current_time - execution_start_time

             # Log Time in a more readable way after 10 minutes
             if execution_time_s > 600:
@@ -342,14 +350,6 @@ def prompt_worker(q, server_instance):
             else:
                 logging.info("Prompt executed in {:.2f} seconds".format(execution_time_s))

-            queue_wait_ms = 0.0
-            created_at = extra_data.get("create_time")
-            if isinstance(created_at, int):
-                queue_wait_ms = max(0.0, execution_start_wall_ms - created_at)
-            if benchmark_mode:
-                server_instance.record_benchmark_result(prompt_id, e.success, execution_time_s * 1000.0, queue_wait_ms)
-
             if not benchmark_mode and not asset_seeder.is_disabled():
                 paths = _collect_output_absolute_paths(e.history_result)
                 register_output_files(paths, job_id=prompt_id)
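
With this change every history entry carries its own timing, which is what the client-side poller above reads. A sketch of the assumed shape; only the "benchmark" block is introduced by this hunk, the rest of the envelope is illustrative.

# Illustrative shape of one /history entry after this change. Values are
# made up; the surrounding envelope is assumed, not taken from the diff.
history_item = {
    "outputs": {},  # emptied when benchmark_mode is set
    "meta": {},
    "status": {"status_str": "success", "messages": []},
    "benchmark": {
        "execution_ms": 1234.5,  # wall time around e.execute(), in ms
        "queue_wait_ms": 87.0,   # execution start minus extra_data["create_time"]
    },
}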


@@ -16,7 +16,6 @@ import struct
 import ssl
 import socket
 import ipaddress
-import threading
 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 from io import BytesIO
@@ -253,17 +252,6 @@ class PromptServer():
         self.client_id = None
         self.on_prompt_handlers = []
-        self._benchmark_lock = threading.Lock()
-        self._benchmark_stats = {
-            "requests_total": 0,
-            "requests_success": 0,
-            "requests_error": 0,
-            "latency_ms_total": 0.0,
-            "latency_ms_max": 0.0,
-            "queue_wait_ms_total": 0.0,
-            "queue_wait_ms_max": 0.0,
-            "last_prompt_id": None,
-        }

         @routes.get('/ws')
         async def websocket_handler(request):
@@ -924,16 +912,6 @@ class PromptServer():
             queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
             return web.json_response(queue_info)

-        @routes.get("/bench/stats")
-        async def get_bench_stats(request):
-            stats = self.get_benchmark_stats()
-            return web.json_response(stats)
-
-        @routes.post("/bench/reset")
-        async def reset_bench_stats(request):
-            self.reset_benchmark_stats()
-            return web.json_response({"status": "ok"})
-
         async def enqueue_prompt(json_data):
             if "number" in json_data:
                 number = float(json_data['number'])
@@ -984,15 +962,6 @@ class PromptServer():
                 }
                 return web.json_response({"error": error, "node_errors": {}}, status=400)

-        @routes.post("/bench/prompt")
-        async def post_bench_prompt(request):
-            json_data = await request.json()
-            json_data = self.trigger_on_prompt(json_data)
-            extra_data = json_data.setdefault("extra_data", {})
-            extra_data["benchmark_server_only"] = True
-            extra_data.setdefault("preview_method", "none")
-            return await enqueue_prompt(json_data)
-
         @routes.post("/prompt")
         async def post_prompt(request):
             logging.info("got prompt")
@@ -1144,41 +1113,6 @@ class PromptServer():
         prompt_info['exec_info'] = exec_info
         return prompt_info

-    def reset_benchmark_stats(self):
-        with self._benchmark_lock:
-            self._benchmark_stats = {
-                "requests_total": 0,
-                "requests_success": 0,
-                "requests_error": 0,
-                "latency_ms_total": 0.0,
-                "latency_ms_max": 0.0,
-                "queue_wait_ms_total": 0.0,
-                "queue_wait_ms_max": 0.0,
-                "last_prompt_id": None,
-            }
-
-    def record_benchmark_result(self, prompt_id, success, latency_ms, queue_wait_ms=0.0):
-        with self._benchmark_lock:
-            self._benchmark_stats["requests_total"] += 1
-            if success:
-                self._benchmark_stats["requests_success"] += 1
-            else:
-                self._benchmark_stats["requests_error"] += 1
-            self._benchmark_stats["latency_ms_total"] += max(0.0, latency_ms)
-            self._benchmark_stats["queue_wait_ms_total"] += max(0.0, queue_wait_ms)
-            self._benchmark_stats["latency_ms_max"] = max(self._benchmark_stats["latency_ms_max"], max(0.0, latency_ms))
-            self._benchmark_stats["queue_wait_ms_max"] = max(self._benchmark_stats["queue_wait_ms_max"], max(0.0, queue_wait_ms))
-            self._benchmark_stats["last_prompt_id"] = prompt_id
-
-    def get_benchmark_stats(self):
-        with self._benchmark_lock:
-            stats = dict(self._benchmark_stats)
-        total = stats["requests_total"]
-        stats["latency_ms_avg"] = (stats["latency_ms_total"] / total) if total > 0 else 0.0
-        stats["queue_wait_ms_avg"] = (stats["queue_wait_ms_total"] / total) if total > 0 else 0.0
-        return stats
-
     async def send(self, event, data, sid=None):
         if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
             await self.send_image(data, sid=sid)
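
With the server-side counters gone, the same summary can be rebuilt in the benchmark script from the per-prompt samples. A minimal sketch: the field names mirror the deleted _benchmark_stats dict, but this helper itself is not part of the diff.

# Sketch of client-side aggregation over collected per-prompt samples,
# each a dict like {"execution_ms": ..., "queue_wait_ms": ...}.
def summarize(samples):
    lat = [s["execution_ms"] for s in samples if s.get("execution_ms") is not None]
    qw = [s["queue_wait_ms"] for s in samples if s.get("queue_wait_ms") is not None]
    return {
        "requests_total": len(samples),
        "latency_ms_avg": sum(lat) / len(lat) if lat else 0.0,
        "latency_ms_max": max(lat, default=0.0),
        "queue_wait_ms_avg": sum(qw) / len(qw) if qw else 0.0,
        "queue_wait_ms_max": max(qw, default=0.0),
    }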