Add per-node benchmark timing.

Tara Ding 2026-04-27 21:39:57 -07:00
parent 059b346966
commit 69f6272edc
3 changed files with 42 additions and 12 deletions


@@ -316,6 +316,7 @@ class RequestResult:
     end_to_end_s: float
     queue_wait_ms: float | None
     execution_ms: float | None
+    node_timing_ms: dict[str, dict] | None


 def percentile(values: list[float], pct: float) -> float:
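
The diff shows only percentile's signature; for reference, a minimal sketch of what such a helper typically does, assuming linear interpolation between ranks (the actual body is not part of this change):

    def percentile(values: list[float], pct: float) -> float:
        # Sketch only: sort, then interpolate between the two nearest ranks.
        ordered = sorted(values)
        if not ordered:
            return 0.0  # assumed fallback for an empty sample
        rank = (pct / 100.0) * (len(ordered) - 1)
        lo, hi = int(rank), min(int(rank) + 1, len(ordered) - 1)
        return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)
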
@@ -394,10 +395,10 @@ async def wait_for_prompt_done(
     prompt_id: str,
     poll_interval_s: float,
     timeout_s: float,
-) -> tuple[float | None, float | None]:
+) -> tuple[float | None, float | None, dict | None]:
     """
-    Returns (queue_wait_ms, execution_ms) from history_item["benchmark"] written by the server.
-    Falls back to (None, None) if unavailable.
+    Returns (queue_wait_ms, execution_ms, node_timing_ms) from history_item["benchmark"].
+    Falls back to (None, None, None) if unavailable.
     """
     deadline = time.perf_counter() + timeout_s
     history_url = f"{base_url}/history/{prompt_id}"

@@ -424,9 +425,11 @@
                continue
            benchmark = history_item.get("benchmark", {})
-            queue_wait_ms = benchmark.get("queue_wait_ms")
-            execution_ms = benchmark.get("execution_ms")
-            return queue_wait_ms, execution_ms
+            return (
+                benchmark.get("queue_wait_ms"),
+                benchmark.get("execution_ms"),
+                benchmark.get("nodes"),
+            )
        await asyncio.sleep(poll_interval_s)
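
For reference, the shape the poll loop parses (the field names come from this diff; the node ids, class names, and timings below are illustrative):

    history_item = {
        "benchmark": {
            "queue_wait_ms": 12.4,
            "execution_ms": 1830.7,
            "nodes": {
                "3": {"class_type": "KSampler", "execution_ms": 1544.2},
                "8": {"class_type": "VAEDecode", "execution_ms": 201.9},
            },
        },
    }
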
@@ -484,7 +487,7 @@ async def run_request(
            timeout_s=args.request_timeout_s,
        )
-        queue_wait_ms, execution_ms = await wait_for_prompt_done(
+        queue_wait_ms, execution_ms, node_timing_ms = await wait_for_prompt_done(
            session=session,
            base_url=args.host,
            prompt_id=prompt_id,

@@ -503,6 +506,7 @@
            end_to_end_s=finished_at - queued_at,
            queue_wait_ms=queue_wait_ms,
            execution_ms=execution_ms,
+            node_timing_ms=node_timing_ms,
        )
    except Exception as exc:
        finished_at = time.perf_counter()

@@ -517,6 +521,7 @@
            end_to_end_s=finished_at - queued_at,
            queue_wait_ms=None,
            execution_ms=None,
+            node_timing_ms=None,
        )

@@ -551,6 +556,19 @@ def print_summary(results: list[RequestResult], wall_s: float) -> None:
        print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}")
        print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}")
+    # Per-node timing: aggregate execution_ms across all successful results.
+    node_totals: dict[str, list[float]] = {}
+    for r in success:
+        if not r.node_timing_ms:
+            continue
+        for node_id, info in r.node_timing_ms.items():
+            key = f"{info.get('class_type', 'unknown')} ({node_id})"
+            node_totals.setdefault(key, []).append(info.get("execution_ms", 0.0))
+    if node_totals:
+        print("\n--- Per-node execution time (mean ms across successful requests) ---")
+        for key, times in sorted(node_totals.items(), key=lambda x: -statistics.mean(x[1])):
+            print(f"  {key}: mean={statistics.mean(times):.1f} p95={percentile(times, 95):.1f} n={len(times)}")
+
    if fail:
        print("\nSample failures:")
        for r in fail[:5]:
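
A hypothetical run might print something like the following (node names and numbers are illustrative):

    --- Per-node execution time (mean ms across successful requests) ---
      KSampler (3): mean=1544.2 p95=1601.8 n=32
      VAEDecode (8): mean=201.9 p95=233.5 n=32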


@@ -721,6 +721,7 @@ class PromptExecutor:
        self.server.client_id = None
        self.status_messages = []
+        self.node_timing_ms: dict[str, dict] = {}
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
        self._notify_prompt_lifecycle("start", prompt_id)

@@ -767,6 +768,7 @@
                    break
                assert node_id is not None, "Node ID should not be None at this point"
+                node_start_s = time.perf_counter() if args.benchmark_server_only else None
                result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
                self.success = result != ExecutionResult.FAILURE
                if result == ExecutionResult.FAILURE:

@@ -776,6 +778,12 @@
                    execution_list.unstage_node_execution()
                else: # result == ExecutionResult.SUCCESS:
                    execution_list.complete_node_execution()
+                    if node_start_s is not None:
+                        class_type = dynamic_prompt.get_node(node_id).get("class_type", "unknown")
+                        self.node_timing_ms[node_id] = {
+                            "class_type": class_type,
+                            "execution_ms": (time.perf_counter() - node_start_s) * 1000.0,
+                        }
                if self.cache_type == CacheType.RAM_PRESSURE:
                    comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
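
For context, the class_type recorded above is read from the node's entry in the submitted prompt graph; a minimal example of one entry in ComfyUI's API-format workflow (inputs left empty for brevity):

    prompt = {
        "3": {
            "class_type": "KSampler",
            "inputs": {},
        },
    }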

main.py

@@ -326,11 +326,15 @@ def prompt_worker(q, server_instance):
            remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
            history_result = e.history_result
            if benchmark_mode:
-                history_result = {"outputs": {}, "meta": {}}
-                history_result["benchmark"] = {
-                    "execution_ms": execution_time_s * 1000.0,
-                    "queue_wait_ms": queue_wait_ms,
-                }
+                history_result = {
+                    "outputs": {},
+                    "meta": {},
+                    "benchmark": {
+                        "execution_ms": execution_time_s * 1000.0,
+                        "queue_wait_ms": queue_wait_ms,
+                        "nodes": e.node_timing_ms,
+                    },
+                }
            q.task_done(item_id,
                        history_result,
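
To inspect the stored payload by hand, a sketch (host and prompt id are placeholders; the /history/{prompt_id} response envelope is assumed to map the prompt id to its history entry):

    import json
    import urllib.request

    prompt_id = "..."  # id returned when the prompt was queued
    url = f"http://127.0.0.1:8188/history/{prompt_id}"
    with urllib.request.urlopen(url) as resp:
        history = json.load(resp)
    entry = history.get(prompt_id, history)  # tolerate either envelope
    print(json.dumps(entry.get("benchmark", {}), indent=2))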