mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-12 18:22:53 +08:00
Add benchmark for each node.
This commit is contained in:
parent
059b346966
commit
69f6272edc
@ -316,6 +316,7 @@ class RequestResult:
|
||||
end_to_end_s: float
|
||||
queue_wait_ms: float | None
|
||||
execution_ms: float | None
|
||||
node_timing_ms: dict[str, dict] | None
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
|
||||
@ -394,10 +395,10 @@ async def wait_for_prompt_done(
|
||||
prompt_id: str,
|
||||
poll_interval_s: float,
|
||||
timeout_s: float,
|
||||
) -> tuple[float | None, float | None]:
|
||||
) -> tuple[float | None, float | None, dict | None]:
|
||||
"""
|
||||
Returns (queue_wait_ms, execution_ms) from history_item["benchmark"] written by the server.
|
||||
Falls back to (None, None) if unavailable.
|
||||
Returns (queue_wait_ms, execution_ms, node_timing_ms) from history_item["benchmark"].
|
||||
Falls back to (None, None, None) if unavailable.
|
||||
"""
|
||||
deadline = time.perf_counter() + timeout_s
|
||||
history_url = f"{base_url}/history/{prompt_id}"
|
||||
@ -424,9 +425,11 @@ async def wait_for_prompt_done(
|
||||
continue
|
||||
|
||||
benchmark = history_item.get("benchmark", {})
|
||||
queue_wait_ms = benchmark.get("queue_wait_ms")
|
||||
execution_ms = benchmark.get("execution_ms")
|
||||
return queue_wait_ms, execution_ms
|
||||
return (
|
||||
benchmark.get("queue_wait_ms"),
|
||||
benchmark.get("execution_ms"),
|
||||
benchmark.get("nodes"),
|
||||
)
|
||||
|
||||
await asyncio.sleep(poll_interval_s)
|
||||
|
||||
@ -484,7 +487,7 @@ async def run_request(
|
||||
timeout_s=args.request_timeout_s,
|
||||
)
|
||||
|
||||
queue_wait_ms, execution_ms = await wait_for_prompt_done(
|
||||
queue_wait_ms, execution_ms, node_timing_ms = await wait_for_prompt_done(
|
||||
session=session,
|
||||
base_url=args.host,
|
||||
prompt_id=prompt_id,
|
||||
@ -503,6 +506,7 @@ async def run_request(
|
||||
end_to_end_s=finished_at - queued_at,
|
||||
queue_wait_ms=queue_wait_ms,
|
||||
execution_ms=execution_ms,
|
||||
node_timing_ms=node_timing_ms,
|
||||
)
|
||||
except Exception as exc:
|
||||
finished_at = time.perf_counter()
|
||||
@ -517,6 +521,7 @@ async def run_request(
|
||||
end_to_end_s=finished_at - queued_at,
|
||||
queue_wait_ms=None,
|
||||
execution_ms=None,
|
||||
node_timing_ms=None,
|
||||
)
|
||||
|
||||
|
||||
@ -551,6 +556,19 @@ def print_summary(results: list[RequestResult], wall_s: float) -> None:
|
||||
print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}")
|
||||
print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}")
|
||||
|
||||
# Per-node timing: aggregate execution_ms across all successful results.
|
||||
node_totals: dict[str, list[float]] = {}
|
||||
for r in success:
|
||||
if not r.node_timing_ms:
|
||||
continue
|
||||
for node_id, info in r.node_timing_ms.items():
|
||||
key = f"{info.get('class_type', 'unknown')} ({node_id})"
|
||||
node_totals.setdefault(key, []).append(info.get("execution_ms", 0.0))
|
||||
if node_totals:
|
||||
print("\n--- Per-node execution time (mean ms across successful requests) ---")
|
||||
for key, times in sorted(node_totals.items(), key=lambda x: -statistics.mean(x[1])):
|
||||
print(f" {key}: mean={statistics.mean(times):.1f} p95={percentile(times, 95):.1f} n={len(times)}")
|
||||
|
||||
if fail:
|
||||
print("\nSample failures:")
|
||||
for r in fail[:5]:
|
||||
|
||||
@ -721,6 +721,7 @@ class PromptExecutor:
|
||||
self.server.client_id = None
|
||||
|
||||
self.status_messages = []
|
||||
self.node_timing_ms: dict[str, dict] = {}
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
@ -767,6 +768,7 @@ class PromptExecutor:
|
||||
break
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
node_start_s = time.perf_counter() if args.benchmark_server_only else None
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
@ -776,6 +778,12 @@ class PromptExecutor:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
if node_start_s is not None:
|
||||
class_type = dynamic_prompt.get_node(node_id).get("class_type", "unknown")
|
||||
self.node_timing_ms[node_id] = {
|
||||
"class_type": class_type,
|
||||
"execution_ms": (time.perf_counter() - node_start_s) * 1000.0,
|
||||
}
|
||||
|
||||
if self.cache_type == CacheType.RAM_PRESSURE:
|
||||
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
||||
|
||||
14
main.py
14
main.py
@ -326,11 +326,15 @@ def prompt_worker(q, server_instance):
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
history_result = e.history_result
|
||||
if benchmark_mode:
|
||||
history_result = {"outputs": {}, "meta": {}}
|
||||
history_result["benchmark"] = {
|
||||
"execution_ms": execution_time_s * 1000.0,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
}
|
||||
history_result = {
|
||||
"outputs": {},
|
||||
"meta": {},
|
||||
"benchmark": {
|
||||
"execution_ms": execution_time_s * 1000.0,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
"nodes": e.node_timing_ms,
|
||||
},
|
||||
}
|
||||
|
||||
q.task_done(item_id,
|
||||
history_result,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user