mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 02:57:24 +08:00
Add benchmark for each node.
This commit is contained in:
parent
059b346966
commit
69f6272edc
@ -316,6 +316,7 @@ class RequestResult:
|
|||||||
end_to_end_s: float
|
end_to_end_s: float
|
||||||
queue_wait_ms: float | None
|
queue_wait_ms: float | None
|
||||||
execution_ms: float | None
|
execution_ms: float | None
|
||||||
|
node_timing_ms: dict[str, dict] | None
|
||||||
|
|
||||||
|
|
||||||
def percentile(values: list[float], pct: float) -> float:
|
def percentile(values: list[float], pct: float) -> float:
|
||||||
@ -394,10 +395,10 @@ async def wait_for_prompt_done(
|
|||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
poll_interval_s: float,
|
poll_interval_s: float,
|
||||||
timeout_s: float,
|
timeout_s: float,
|
||||||
) -> tuple[float | None, float | None]:
|
) -> tuple[float | None, float | None, dict | None]:
|
||||||
"""
|
"""
|
||||||
Returns (queue_wait_ms, execution_ms) from history_item["benchmark"] written by the server.
|
Returns (queue_wait_ms, execution_ms, node_timing_ms) from history_item["benchmark"].
|
||||||
Falls back to (None, None) if unavailable.
|
Falls back to (None, None, None) if unavailable.
|
||||||
"""
|
"""
|
||||||
deadline = time.perf_counter() + timeout_s
|
deadline = time.perf_counter() + timeout_s
|
||||||
history_url = f"{base_url}/history/{prompt_id}"
|
history_url = f"{base_url}/history/{prompt_id}"
|
||||||
@ -424,9 +425,11 @@ async def wait_for_prompt_done(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
benchmark = history_item.get("benchmark", {})
|
benchmark = history_item.get("benchmark", {})
|
||||||
queue_wait_ms = benchmark.get("queue_wait_ms")
|
return (
|
||||||
execution_ms = benchmark.get("execution_ms")
|
benchmark.get("queue_wait_ms"),
|
||||||
return queue_wait_ms, execution_ms
|
benchmark.get("execution_ms"),
|
||||||
|
benchmark.get("nodes"),
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.sleep(poll_interval_s)
|
await asyncio.sleep(poll_interval_s)
|
||||||
|
|
||||||
@ -484,7 +487,7 @@ async def run_request(
|
|||||||
timeout_s=args.request_timeout_s,
|
timeout_s=args.request_timeout_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
queue_wait_ms, execution_ms = await wait_for_prompt_done(
|
queue_wait_ms, execution_ms, node_timing_ms = await wait_for_prompt_done(
|
||||||
session=session,
|
session=session,
|
||||||
base_url=args.host,
|
base_url=args.host,
|
||||||
prompt_id=prompt_id,
|
prompt_id=prompt_id,
|
||||||
@ -503,6 +506,7 @@ async def run_request(
|
|||||||
end_to_end_s=finished_at - queued_at,
|
end_to_end_s=finished_at - queued_at,
|
||||||
queue_wait_ms=queue_wait_ms,
|
queue_wait_ms=queue_wait_ms,
|
||||||
execution_ms=execution_ms,
|
execution_ms=execution_ms,
|
||||||
|
node_timing_ms=node_timing_ms,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
finished_at = time.perf_counter()
|
finished_at = time.perf_counter()
|
||||||
@ -517,6 +521,7 @@ async def run_request(
|
|||||||
end_to_end_s=finished_at - queued_at,
|
end_to_end_s=finished_at - queued_at,
|
||||||
queue_wait_ms=None,
|
queue_wait_ms=None,
|
||||||
execution_ms=None,
|
execution_ms=None,
|
||||||
|
node_timing_ms=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -551,6 +556,19 @@ def print_summary(results: list[RequestResult], wall_s: float) -> None:
|
|||||||
print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}")
|
print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}")
|
||||||
print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}")
|
print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}")
|
||||||
|
|
||||||
|
# Per-node timing: aggregate execution_ms across all successful results.
|
||||||
|
node_totals: dict[str, list[float]] = {}
|
||||||
|
for r in success:
|
||||||
|
if not r.node_timing_ms:
|
||||||
|
continue
|
||||||
|
for node_id, info in r.node_timing_ms.items():
|
||||||
|
key = f"{info.get('class_type', 'unknown')} ({node_id})"
|
||||||
|
node_totals.setdefault(key, []).append(info.get("execution_ms", 0.0))
|
||||||
|
if node_totals:
|
||||||
|
print("\n--- Per-node execution time (mean ms across successful requests) ---")
|
||||||
|
for key, times in sorted(node_totals.items(), key=lambda x: -statistics.mean(x[1])):
|
||||||
|
print(f" {key}: mean={statistics.mean(times):.1f} p95={percentile(times, 95):.1f} n={len(times)}")
|
||||||
|
|
||||||
if fail:
|
if fail:
|
||||||
print("\nSample failures:")
|
print("\nSample failures:")
|
||||||
for r in fail[:5]:
|
for r in fail[:5]:
|
||||||
|
|||||||
@ -721,6 +721,7 @@ class PromptExecutor:
|
|||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
|
self.node_timing_ms: dict[str, dict] = {}
|
||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
self._notify_prompt_lifecycle("start", prompt_id)
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
@ -767,6 +768,7 @@ class PromptExecutor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
assert node_id is not None, "Node ID should not be None at this point"
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
|
node_start_s = time.perf_counter() if args.benchmark_server_only else None
|
||||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||||
self.success = result != ExecutionResult.FAILURE
|
self.success = result != ExecutionResult.FAILURE
|
||||||
if result == ExecutionResult.FAILURE:
|
if result == ExecutionResult.FAILURE:
|
||||||
@ -776,6 +778,12 @@ class PromptExecutor:
|
|||||||
execution_list.unstage_node_execution()
|
execution_list.unstage_node_execution()
|
||||||
else: # result == ExecutionResult.SUCCESS:
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
execution_list.complete_node_execution()
|
execution_list.complete_node_execution()
|
||||||
|
if node_start_s is not None:
|
||||||
|
class_type = dynamic_prompt.get_node(node_id).get("class_type", "unknown")
|
||||||
|
self.node_timing_ms[node_id] = {
|
||||||
|
"class_type": class_type,
|
||||||
|
"execution_ms": (time.perf_counter() - node_start_s) * 1000.0,
|
||||||
|
}
|
||||||
|
|
||||||
if self.cache_type == CacheType.RAM_PRESSURE:
|
if self.cache_type == CacheType.RAM_PRESSURE:
|
||||||
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
||||||
|
|||||||
14
main.py
14
main.py
@ -326,11 +326,15 @@ def prompt_worker(q, server_instance):
|
|||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
history_result = e.history_result
|
history_result = e.history_result
|
||||||
if benchmark_mode:
|
if benchmark_mode:
|
||||||
history_result = {"outputs": {}, "meta": {}}
|
history_result = {
|
||||||
history_result["benchmark"] = {
|
"outputs": {},
|
||||||
"execution_ms": execution_time_s * 1000.0,
|
"meta": {},
|
||||||
"queue_wait_ms": queue_wait_ms,
|
"benchmark": {
|
||||||
}
|
"execution_ms": execution_time_s * 1000.0,
|
||||||
|
"queue_wait_ms": queue_wait_ms,
|
||||||
|
"nodes": e.node_timing_ms,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
q.task_done(item_id,
|
q.task_done(item_id,
|
||||||
history_result,
|
history_result,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user