mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-14 02:57:24 +08:00)

Move benchmark serving client into benchmarks folder

This commit is contained in:
parent 96363fa74a
commit 00379b4acf

benchmarks/benchmark_comfyui_serving.py (new file, 374 lines)
@@ -0,0 +1,374 @@
#!/usr/bin/env python3
"""
Simple serving benchmark client for ComfyUI's HTTP API.

This script is inspired by diffusion serving benchmarks and is designed to:
- submit prompts to ComfyUI (/prompt or /bench/prompt),
- optionally shape request arrivals (fixed rate or Poisson),
- poll completion via /history/{prompt_id},
- report latency/throughput/error metrics.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import math
import random
import statistics
import time
import uuid
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any

import aiohttp


@dataclass
class RequestResult:
    request_index: int
    prompt_id: str | None
    ok: bool
    error: str | None
    queued_at: float
    started_at: float
    finished_at: float
    end_to_end_s: float
    queue_wait_ms: float | None
    execution_ms: float | None


def percentile(values: list[float], pct: float) -> float:
    if not values:
        return float("nan")
    if len(values) == 1:
        return values[0]
    values = sorted(values)
    rank = (len(values) - 1) * (pct / 100.0)
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        return values[lower]
    weight = rank - lower
    return values[lower] * (1.0 - weight) + values[upper] * weight
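
# Example of the linear-interpolation convention above:
#   percentile([1.0, 2.0, 3.0, 4.0], 90)
#   rank = (4 - 1) * 0.9 = 2.7 -> lower = 2, upper = 3, weight = 0.7
#   result = 3.0 * 0.3 + 4.0 * 0.7 = 3.7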


def patch_seed_in_prompt(prompt: dict[str, Any], seed: int, seed_path: str | None) -> dict[str, Any]:
    """
    Patch prompt seed in-place for common sampler nodes.
    seed_path format: "<node_id>.<input_name>".
    """
    if seed_path:
        try:
            node_id, input_name = seed_path.split(".", 1)
            prompt[node_id]["inputs"][input_name] = seed
            return prompt
        except Exception as exc:
            raise ValueError(f"Invalid --seed-path '{seed_path}': {exc}") from exc

    # Best-effort fallback: update any input key named 'seed' or 'noise_seed'
    for node in prompt.values():
        if not isinstance(node, dict):
            continue
        inputs = node.get("inputs")
        if not isinstance(inputs, dict):
            continue
        if "seed" in inputs:
            inputs["seed"] = seed
        if "noise_seed" in inputs:
            inputs["noise_seed"] = seed
    return prompt
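
# Example: given a graph like {"3": {"class_type": "KSampler", "inputs": {"seed": 0, ...}}},
# patch_seed_in_prompt(prompt, 42, "3.seed") sets prompt["3"]["inputs"]["seed"] = 42.
# With seed_path=None, every node input literally named 'seed' or 'noise_seed' is overwritten.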


def load_prompt_template(path: Path) -> dict[str, Any]:
    data = json.loads(path.read_text())
    if "prompt" in data and isinstance(data["prompt"], dict):
        return data
    if isinstance(data, dict):
        return {"prompt": data}
    raise ValueError("Prompt file must be a JSON object (prompt graph or wrapper with 'prompt').")
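
# Both accepted input shapes normalize to the wrapper ComfyUI's /prompt endpoint expects:
#   {"3": {...}, "4": {...}}          -> {"prompt": {"3": {...}, "4": {...}}}
#   {"prompt": {"3": {...}}, ...}     -> returned unchanged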


async def submit_prompt(
    session: aiohttp.ClientSession,
    base_url: str,
    endpoint: str,
    payload: dict[str, Any],
    timeout_s: float,
) -> str:
    url = f"{base_url}{endpoint}"
    async with session.post(url, json=payload, timeout=timeout_s) as resp:
        text = await resp.text()
        if resp.status != 200:
            raise RuntimeError(f"submit failed [{resp.status}] {text}")
        body = json.loads(text)
        prompt_id = body.get("prompt_id")
        if not prompt_id:
            raise RuntimeError(f"missing prompt_id in response: {body}")
        return prompt_id


async def wait_for_prompt_done(
    session: aiohttp.ClientSession,
    base_url: str,
    prompt_id: str,
    poll_interval_s: float,
    timeout_s: float,
) -> tuple[float | None, float | None]:
    """
    Returns (queue_wait_ms, execution_ms) when available from history status messages.
    Falls back to (None, None) if unavailable.
    """
    deadline = time.perf_counter() + timeout_s
    history_url = f"{base_url}/history/{prompt_id}"

    while time.perf_counter() < deadline:
        async with session.get(history_url, timeout=timeout_s) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise RuntimeError(f"history failed [{resp.status}] {text}")

            payload = await resp.json()
            if not payload:
                await asyncio.sleep(poll_interval_s)
                continue

            history_item = payload.get(prompt_id)
            if history_item is None:
                await asyncio.sleep(poll_interval_s)
                continue

            status = history_item.get("status", {})
            status_str = status.get("status_str")
            messages = status.get("messages", [])
            if status_str not in ("success", "error"):
                await asyncio.sleep(poll_interval_s)
                continue

            queue_wait_ms = None
            execution_ms = None
            try:
                timestamp_map: dict[str, int] = {}
                for event, msg in messages:
                    if isinstance(msg, dict) and "timestamp" in msg:
                        timestamp_map[event] = int(msg["timestamp"])
                start_ts = timestamp_map.get("execution_start")
                end_ts = timestamp_map.get("execution_success") or timestamp_map.get("execution_error")
                if start_ts is not None and end_ts is not None:
                    execution_ms = max(0.0, end_ts - start_ts)
            except Exception:
                execution_ms = None

            return queue_wait_ms, execution_ms

        await asyncio.sleep(poll_interval_s)

    raise TimeoutError(f"timed out waiting for prompt_id={prompt_id}")
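
# Note: history "status.messages" entries are (event, data) pairs, and recent
# ComfyUI builds attach a wall-clock "timestamp" (ms) to the execution_start /
# execution_success / execution_error events; that is what the parser above
# relies on, falling back to None when the timestamps are absent.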


def build_arrival_schedule(num_requests: int, request_rate: float, poisson: bool, seed: int) -> list[float]:
    """
    Returns absolute offsets (seconds from benchmark start) for each request.
    """
    if request_rate <= 0:
        return [0.0] * num_requests

    rnd = random.Random(seed)
    offsets: list[float] = []
    t = 0.0
    for _ in range(num_requests):
        if poisson:
            delta = rnd.expovariate(request_rate)
        else:
            delta = 1.0 / request_rate
        t += delta
        offsets.append(t)
    return offsets
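
# Example: request_rate=2.0 yields fixed offsets [0.5, 1.0, 1.5, ...]; with
# poisson=True the gaps are instead drawn from expovariate(2.0), whose mean is
# the same 0.5 s, so average load matches but arrivals are bursty.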


async def run_request(
    idx: int,
    start_time: float,
    scheduled_offset_s: float,
    semaphore: asyncio.Semaphore,
    session: aiohttp.ClientSession,
    args: argparse.Namespace,
    prompt_wrapper_template: dict[str, Any],
) -> RequestResult:
    await asyncio.sleep(max(0.0, (start_time + scheduled_offset_s) - time.perf_counter()))
    queued_at = time.perf_counter()

    async with semaphore:
        started_at = time.perf_counter()
        prompt_id = None
        try:
            payload = json.loads(json.dumps(prompt_wrapper_template))  # cheap deep copy per request
            payload.setdefault("extra_data", {})
            payload["client_id"] = args.client_id

            seed = args.base_seed + idx
            payload["prompt"] = patch_seed_in_prompt(payload["prompt"], seed, args.seed_path)

            prompt_id = await submit_prompt(
                session=session,
                base_url=args.host,
                endpoint=args.endpoint,
                payload=payload,
                timeout_s=args.request_timeout_s,
            )

            queue_wait_ms, execution_ms = await wait_for_prompt_done(
                session=session,
                base_url=args.host,
                prompt_id=prompt_id,
                poll_interval_s=args.poll_interval_s,
                timeout_s=args.request_timeout_s,
            )
            finished_at = time.perf_counter()
            return RequestResult(
                request_index=idx,
                prompt_id=prompt_id,
                ok=True,
                error=None,
                queued_at=queued_at,
                started_at=started_at,
                finished_at=finished_at,
                end_to_end_s=finished_at - queued_at,
                queue_wait_ms=queue_wait_ms,
                execution_ms=execution_ms,
            )
        except Exception as exc:
            finished_at = time.perf_counter()
            return RequestResult(
                request_index=idx,
                prompt_id=prompt_id,
                ok=False,
                error=repr(exc),
                queued_at=queued_at,
                started_at=started_at,
                finished_at=finished_at,
                end_to_end_s=finished_at - queued_at,
                queue_wait_ms=None,
                execution_ms=None,
            )


def print_summary(results: list[RequestResult], wall_s: float) -> None:
    success = [r for r in results if r.ok]
    fail = [r for r in results if not r.ok]
    lat_s = [r.end_to_end_s for r in success]
    queue_wait_ms = [r.queue_wait_ms for r in success if r.queue_wait_ms is not None]
    exec_ms = [r.execution_ms for r in success if r.execution_ms is not None]

    throughput = (len(success) / wall_s) if wall_s > 0 else 0.0
    print("\n=== ComfyUI Serving Benchmark Summary ===")
    print(f"requests_total: {len(results)}")
    print(f"requests_success: {len(success)}")
    print(f"requests_failed: {len(fail)}")
    print(f"wall_time_s: {wall_s:.3f}")
    print(f"throughput_req_s: {throughput:.3f}")

    if lat_s:
        print(f"latency_p50_s: {percentile(lat_s, 50):.3f}")
        print(f"latency_p90_s: {percentile(lat_s, 90):.3f}")
        print(f"latency_p95_s: {percentile(lat_s, 95):.3f}")
        print(f"latency_p99_s: {percentile(lat_s, 99):.3f}")
        print(f"latency_mean_s: {statistics.mean(lat_s):.3f}")
        print(f"latency_max_s: {max(lat_s):.3f}")

    if queue_wait_ms:
        print(f"queue_wait_mean_ms: {statistics.mean(queue_wait_ms):.2f}")
        print(f"queue_wait_p95_ms: {percentile(queue_wait_ms, 95):.2f}")

    if exec_ms:
        print(f"execution_mean_ms: {statistics.mean(exec_ms):.2f}")
        print(f"execution_p95_ms: {percentile(exec_ms, 95):.2f}")

    if fail:
        print("\nSample failures:")
        for r in fail[:5]:
            print(f"  idx={r.request_index} prompt_id={r.prompt_id} error={r.error}")


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Benchmark ComfyUI request serving.")
    p.add_argument("--host", type=str, default="http://127.0.0.1:8188", help="ComfyUI base URL.")
    p.add_argument(
        "--endpoint",
        type=str,
        default="/prompt",
        choices=("/prompt", "/bench/prompt"),
        help="Submission endpoint.",
    )
    p.add_argument("--prompt-file", type=Path, required=True, help="Path to prompt JSON.")
    p.add_argument("--num-requests", type=int, default=50)
    p.add_argument("--max-concurrency", type=int, default=8)
    p.add_argument("--request-rate", type=float, default=0.0, help="Requests/sec. 0 = fire immediately.")
    p.add_argument("--poisson", action="store_true", help="Use Poisson inter-arrival when request-rate > 0.")
    p.add_argument("--base-seed", type=int, default=1234)
    p.add_argument(
        "--seed-path",
        type=str,
        default=None,
        help="Optional path to seed field in prompt: <node_id>.<input_name> (e.g. 3.seed).",
    )
    p.add_argument("--client-id", type=str, default=f"bench-{uuid.uuid4().hex[:12]}")
    p.add_argument("--request-timeout-s", type=float, default=600.0)
    p.add_argument("--poll-interval-s", type=float, default=0.2)
    p.add_argument("--output-json", type=Path, default=None, help="Write detailed result JSON.")
    p.add_argument("--seed", type=int, default=0, help="RNG seed for schedule generation.")
    return p.parse_args()


async def async_main(args: argparse.Namespace) -> None:
    prompt_template = load_prompt_template(args.prompt_file)
    schedule = build_arrival_schedule(
        num_requests=args.num_requests,
        request_rate=args.request_rate,
        poisson=args.poisson,
        seed=args.seed,
    )
    semaphore = asyncio.Semaphore(args.max_concurrency)
    connector = aiohttp.TCPConnector(limit=max(args.max_concurrency * 2, 32))

    started = time.perf_counter()
    async with aiohttp.ClientSession(connector=connector) as session:
        tasks = [
            asyncio.create_task(
                run_request(
                    idx=i,
                    start_time=started,
                    scheduled_offset_s=schedule[i],
                    semaphore=semaphore,
                    session=session,
                    args=args,
                    prompt_wrapper_template=prompt_template,
                )
            )
            for i in range(args.num_requests)
        ]
        results = await asyncio.gather(*tasks)
        wall_s = time.perf_counter() - started

    print_summary(results, wall_s)

    if args.output_json is not None:
        out = {
            # Path objects in the argparse namespace are not JSON-serializable;
            # stringify them before dumping.
            "config": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
            "wall_time_s": wall_s,
            "results": [asdict(r) for r in sorted(results, key=lambda x: x.request_index)],
        }
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(json.dumps(out, indent=2))
        print(f"\nWrote results to: {args.output_json}")


def main() -> None:
    args = parse_args()
    asyncio.run(async_main(args))


if __name__ == "__main__":
    main()
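
A typical invocation against a local server might look like the following (the
workflow file name is illustrative; the flags match the argparse definitions above):

    python benchmarks/benchmark_comfyui_serving.py \
        --prompt-file workflow_api.json \
        --endpoint /bench/prompt \
        --num-requests 100 --max-concurrency 8 \
        --request-rate 2.0 --poisson \
        --output-json bench_results.json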
comfy/cli_args.py
@@ -225,6 +225,7 @@ parser.add_argument(
 parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")

 parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
+parser.add_argument("--benchmark-server-only", action="store_true", help="Enable lightweight benchmark routes and worker fast-paths focused on model serving throughput/latency.")

 parser.add_argument(
     "--comfy-api-base",
main.py (37 lines changed)
@@ -293,6 +293,7 @@ def prompt_worker(q, server_instance):
     gc_collect_interval = 10.0

     while True:
+        benchmark_mode = args.benchmark_server_only
         timeout = 1000.0
         if need_gc:
             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
@@ -301,6 +302,7 @@ def prompt_worker(q, server_instance):
         if queue_item is not None:
             item, item_id = queue_item
             execution_start_time = time.perf_counter()
+            execution_start_wall_ms = int(time.time() * 1000)
             prompt_id = item[1]
             server_instance.last_prompt_id = prompt_id

@@ -308,15 +310,21 @@ def prompt_worker(q, server_instance):
             extra_data = item[3].copy()
             for k in sensitive:
                 extra_data[k] = sensitive[k]
+            benchmark_mode = args.benchmark_server_only or extra_data.get("benchmark_server_only", False)
+
-            asset_seeder.pause()
+            if not benchmark_mode:
+                asset_seeder.pause()
             e.execute(item[2], prompt_id, extra_data, item[4])

             need_gc = True

             remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
+            history_result = e.history_result
+            if benchmark_mode:
+                history_result = {"outputs": {}, "meta": {}}
+
             q.task_done(item_id,
-                        e.history_result,
+                        history_result,
                         status=execution.PromptQueue.ExecutionStatus(
                             status_str='success' if e.success else 'error',
                             completed=e.success,
@@ -325,16 +333,24 @@ def prompt_worker(q, server_instance):
             server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

             current_time = time.perf_counter()
-            execution_time = current_time - execution_start_time
+            execution_time_s = current_time - execution_start_time

             # Log Time in a more readable way after 10 minutes
-            if execution_time > 600:
-                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
-                logging.info(f"Prompt executed in {execution_time}")
+            if execution_time_s > 600:
+                execution_time_formatted = time.strftime("%H:%M:%S", time.gmtime(execution_time_s))
+                logging.info(f"Prompt executed in {execution_time_formatted}")
             else:
-                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
+                logging.info("Prompt executed in {:.2f} seconds".format(execution_time_s))

-            if not asset_seeder.is_disabled():
+            queue_wait_ms = 0.0
+            created_at = extra_data.get("create_time")
+            if isinstance(created_at, int):
+                queue_wait_ms = max(0.0, execution_start_wall_ms - created_at)
+
+            if benchmark_mode:
+                server_instance.record_benchmark_result(prompt_id, e.success, execution_time_s * 1000.0, queue_wait_ms)
+
+            if not benchmark_mode and not asset_seeder.is_disabled():
                 paths = _collect_output_absolute_paths(e.history_result)
                 register_output_files(paths, job_id=prompt_id)

@@ -360,9 +376,10 @@ def prompt_worker(q, server_instance):
             need_gc = False
         hook_breaker_ac10a0.restore_functions()

-        if not asset_seeder.is_disabled():
+        if not benchmark_mode and not asset_seeder.is_disabled():
             asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
-        asset_seeder.resume()
+        if not benchmark_mode:
+            asset_seeder.resume()


 async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
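
Since the worker honors extra_data, any client can opt a single request into
benchmark mode without the --benchmark-server-only flag. A payload sketch
(graph elided) looks like:

    {"prompt": {...}, "client_id": "bench-123", "extra_data": {"benchmark_server_only": true}}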
server.py (78 lines changed)
@@ -16,6 +16,7 @@ import struct
 import ssl
 import socket
 import ipaddress
+import threading
 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 from io import BytesIO
@@ -252,6 +253,17 @@ class PromptServer():
         self.client_id = None

         self.on_prompt_handlers = []
+        self._benchmark_lock = threading.Lock()
+        self._benchmark_stats = {
+            "requests_total": 0,
+            "requests_success": 0,
+            "requests_error": 0,
+            "latency_ms_total": 0.0,
+            "latency_ms_max": 0.0,
+            "queue_wait_ms_total": 0.0,
+            "queue_wait_ms_max": 0.0,
+            "last_prompt_id": None,
+        }

         @routes.get('/ws')
         async def websocket_handler(request):
@@ -912,12 +924,17 @@ class PromptServer():
             queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
             return web.json_response(queue_info)

-        @routes.post("/prompt")
-        async def post_prompt(request):
-            logging.info("got prompt")
-            json_data = await request.json()
-            json_data = self.trigger_on_prompt(json_data)
+        @routes.get("/bench/stats")
+        async def get_bench_stats(request):
+            stats = self.get_benchmark_stats()
+            return web.json_response(stats)
+
+        @routes.post("/bench/reset")
+        async def reset_bench_stats(request):
+            self.reset_benchmark_stats()
+            return web.json_response({"status": "ok"})
+
+        async def enqueue_prompt(json_data):
             if "number" in json_data:
                 number = float(json_data['number'])
             else:
@@ -967,6 +984,22 @@ class PromptServer():
                 }
                 return web.json_response({"error": error, "node_errors": {}}, status=400)

+        @routes.post("/bench/prompt")
+        async def post_bench_prompt(request):
+            json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)
+            extra_data = json_data.setdefault("extra_data", {})
+            extra_data["benchmark_server_only"] = True
+            extra_data.setdefault("preview_method", "none")
+            return await enqueue_prompt(json_data)
+
+        @routes.post("/prompt")
+        async def post_prompt(request):
+            logging.info("got prompt")
+            json_data = await request.json()
+            json_data = self.trigger_on_prompt(json_data)
+            return await enqueue_prompt(json_data)
+
         @routes.post("/queue")
         async def post_queue(request):
             json_data = await request.json()
@@ -1111,6 +1144,41 @@ class PromptServer():
         prompt_info['exec_info'] = exec_info
         return prompt_info

+    def reset_benchmark_stats(self):
+        with self._benchmark_lock:
+            self._benchmark_stats = {
+                "requests_total": 0,
+                "requests_success": 0,
+                "requests_error": 0,
+                "latency_ms_total": 0.0,
+                "latency_ms_max": 0.0,
+                "queue_wait_ms_total": 0.0,
+                "queue_wait_ms_max": 0.0,
+                "last_prompt_id": None,
+            }
+
+    def record_benchmark_result(self, prompt_id, success, latency_ms, queue_wait_ms=0.0):
+        with self._benchmark_lock:
+            self._benchmark_stats["requests_total"] += 1
+            if success:
+                self._benchmark_stats["requests_success"] += 1
+            else:
+                self._benchmark_stats["requests_error"] += 1
+            self._benchmark_stats["latency_ms_total"] += max(0.0, latency_ms)
+            self._benchmark_stats["queue_wait_ms_total"] += max(0.0, queue_wait_ms)
+            self._benchmark_stats["latency_ms_max"] = max(self._benchmark_stats["latency_ms_max"], max(0.0, latency_ms))
+            self._benchmark_stats["queue_wait_ms_max"] = max(self._benchmark_stats["queue_wait_ms_max"], max(0.0, queue_wait_ms))
+            self._benchmark_stats["last_prompt_id"] = prompt_id
+
+    def get_benchmark_stats(self):
+        with self._benchmark_lock:
+            stats = dict(self._benchmark_stats)
+
+        total = stats["requests_total"]
+        stats["latency_ms_avg"] = (stats["latency_ms_total"] / total) if total > 0 else 0.0
+        stats["queue_wait_ms_avg"] = (stats["queue_wait_ms_total"] / total) if total > 0 else 0.0
+        return stats
+
     async def send(self, event, data, sid=None):
         if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
             await self.send_image(data, sid=sid)
|
|||||||
Loading…
Reference in New Issue
Block a user