Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add benchmark for model serving
This commit is contained in:
parent
7636599389
commit
ac85d7887f
benchmarks/benchmarking_model_serving.py (new file, 444 lines)
@@ -0,0 +1,444 @@
from __future__ import annotations

import argparse
import asyncio
import dataclasses
import json
import math
import os
import random
import statistics
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch

# Make the repo root importable when running directly from the benchmarks/ dir.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import comfy.model_management
import comfy.sd


# -----------------------------
# Data models
# -----------------------------

@dataclasses.dataclass
class RequestSpec:
    profile_name: str
    batch_size: int
    width: int
    height: int
    num_frames: int
    steps: int
    cfg_scale: float
    seed: int
    timeout_s: float = 180.0
    extra: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class RequestResult:
    request_id: int
    profile_name: str
    ok: bool
    error: Optional[str]
    latency_ms: float
    queue_wait_ms: float
    step_latencies_ms: List[float]
    ttfs_ms: float  # time to first (denoising) step
    peak_vram_mb: float
    est_mem_mb: Optional[float]
    started_at: float
    ended_at: float


@dataclasses.dataclass
class RunSummary:
    total_requests: int
    success: int
    failed: int
    throughput_req_s: float
    p50_ms: float
    p90_ms: float
    p95_ms: float
    p99_ms: float
    mean_ms: float
    ttfs_p50_ms: float
    ttfs_p99_ms: float
    step_mean_ms: float
    step_p99_ms: float
    max_vram_mb: float


# -----------------------------
# Helpers
# -----------------------------

def percentile(values: List[float], p: float) -> float:
    if not values:
        return float("nan")
    values = sorted(values)
    k = (len(values) - 1) * (p / 100.0)
    f = math.floor(k)
    c = math.ceil(k)
    if f == c:
        return values[int(k)]
    return values[f] * (c - k) + values[c] * (k - f)
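
# Worked example (hypothetical values): percentile([10, 20, 30, 40], 90) sorts the
# list, computes rank k = (4 - 1) * 0.90 = 2.7, then linearly interpolates between
# values[2] and values[3]: 30 * 0.3 + 40 * 0.7 = 37.0.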


def now() -> float:
    return time.perf_counter()


def gpu_peak_mb() -> float:
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / (1024 ** 2)


def reset_gpu_peak() -> None:
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


def sync_cuda() -> None:
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def build_request_stream(
    num_requests: int,
    base_seed: int,
    profiles: List[RequestSpec],
    weighted: Optional[List[float]] = None,
) -> List[RequestSpec]:
    rnd = random.Random(base_seed)
    out: List[RequestSpec] = []
    for i in range(num_requests):
        p = rnd.choices(profiles, weights=weighted, k=1)[0]
        out.append(dataclasses.replace(p, seed=base_seed + i))
    return out
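
# Illustrative call (assumed weights, not wired to the CLI):
#   build_request_stream(100, 1234, default_profiles(), weighted=[0.5, 0.3, 0.2])
# draws roughly half of the 100 requests from the first profile; every request
# still receives a distinct seed (base_seed + i), so the stream is reproducible.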


# -----------------------------
# Model adapter
# -----------------------------

class WanRunner:
    """
    Thin adapter around ComfyUI model loading + the BaseModel.apply_model call path.

    Only the DiT denoiser is timed — no VAE encode/decode, no CLIP, no scheduler
    overhead — so measurements reflect true model inference cost.

    Latent shape convention (WAN): [B, 16, T, H//8, W//8]
    Text conditioning shape (UMT5): [B, text_seq_len, text_dim] (zeros for benchmarking)
    Sigma schedule (flow-matching): linspace(1.0 → 1/steps, steps)
    """

    def __init__(
        self,
        checkpoint: str,
        device: str,
        dtype_str: str,
        text_seq_len: int = 512,
        text_dim: int = 4096,
    ):
        self.checkpoint = checkpoint
        self.device_str = device
        self.dtype_str = dtype_str
        self.text_seq_len = text_seq_len
        self.text_dim = text_dim
        self.patcher, self.model = self._load_model()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _load_model(self):
        dtype_map = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        dtype = dtype_map.get(self.dtype_str)
        model_opts = {"dtype": dtype} if dtype is not None else {}

        patcher = comfy.sd.load_diffusion_model(self.checkpoint, model_options=model_opts)
        # force_full_load=True keeps the whole model resident on GPU rather than
        # streaming weights on demand (important for latency benchmarking).
        comfy.model_management.load_models_gpu([patcher], force_full_load=True)
        return patcher, patcher.model

    def _estimate_mem_mb(self, latent_shape: tuple, text_seq_len: int) -> Optional[float]:
        cond_shapes = {
            "c_crossattn": [(latent_shape[0], text_seq_len, self.text_dim)],
        }
        try:
            return self.model.memory_required(latent_shape, cond_shapes) / (1024 ** 2)
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Single-request execution
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def run_one(self, req: RequestSpec) -> RequestResult:
        start = now()
        reset_gpu_peak()

        step_latencies: List[float] = []
        ttfs_ms = float("nan")
        est_mem_mb: Optional[float] = None
        ok = True
        err = None

        try:
            device = comfy.model_management.get_torch_device()
            dtype = self.model.get_dtype_inference()

            # Latent noise tensor: [B, 16 channels, T frames, H/8, W/8]
            latent_shape = (
                req.batch_size, 16,
                req.num_frames,
                req.height // 8,
                req.width // 8,
            )
            x = torch.randn(latent_shape, dtype=dtype, device=device)
            est_mem_mb = self._estimate_mem_mb(latent_shape, self.text_seq_len)

            # Fake text conditioning — zeros have the right shape, non-zero
            # values are not needed for throughput/latency benchmarking.
            cross_attn = torch.zeros(
                req.batch_size, self.text_seq_len, self.text_dim,
                dtype=dtype, device=device,
            )

            # Linear sigma schedule: 1.0 → 1/steps (flow-matching, noise→clean)
            sigmas = torch.linspace(1.0, 1.0 / req.steps, req.steps, device=device)

            for step_i, sigma_val in enumerate(sigmas):
                sigma_t = sigma_val.expand(req.batch_size)
                t0 = now()
                x = self.model.apply_model(x, sigma_t, c_crossattn=cross_attn)
                sync_cuda()
                elapsed_ms = (now() - t0) * 1000.0
                step_latencies.append(elapsed_ms)
                if step_i == 0:
                    ttfs_ms = elapsed_ms

        except Exception as e:
            ok = False
            err = repr(e)

        end = now()
        return RequestResult(
            request_id=-1,
            profile_name=req.profile_name,
            ok=ok,
            error=err,
            latency_ms=(end - start) * 1000.0,
            queue_wait_ms=0.0,  # filled in by the scheduler
            step_latencies_ms=step_latencies,
            ttfs_ms=ttfs_ms,
            peak_vram_mb=gpu_peak_mb(),
            est_mem_mb=est_mem_mb,
            started_at=start,
            ended_at=end,
        )
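
# Standalone sanity check (hypothetical checkpoint path, not part of the CLI flow):
#   runner = WanRunner("wan2.1_t2v.safetensors", device="cuda", dtype_str="fp16")
#   spec = RequestSpec("smoke_test", 1, 832, 480, 8, 4, 6.0, 0)  # 8 frames, 4 steps
#   print(runner.run_one(spec).latency_ms)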


# -----------------------------
# Serving-style scheduler
# -----------------------------

async def run_closed_loop(
    runner: WanRunner,
    requests: List[RequestSpec],
    concurrency: int,
    request_rate: float = float("inf"),
) -> List[RequestResult]:
    """
    Closed-loop scheduler (default) or Poisson open-loop when request_rate is finite.

    Each request is dispatched to a thread so the asyncio event loop stays
    free to issue the next request while the GPU is busy.
    """
    sem = asyncio.Semaphore(concurrency)
    results: List[Optional[RequestResult]] = [None] * len(requests)

    async def worker(i: int, req: RequestSpec) -> None:
        # Queue wait is measured from submission, so it includes time spent
        # blocked on the concurrency semaphore, not just thread dispatch.
        t_enq = now()
        async with sem:
            res = await asyncio.to_thread(runner.run_one, req)
            res.request_id = i
            res.queue_wait_ms = max(0.0, (res.started_at - t_enq) * 1000.0)
            results[i] = res

    if request_rate == float("inf") or request_rate <= 0:
        await asyncio.gather(*(worker(i, r) for i, r in enumerate(requests)))
    else:
        tasks: List[asyncio.Task] = []
        for i, req in enumerate(requests):
            if i > 0:
                await asyncio.sleep(random.expovariate(request_rate))
            tasks.append(asyncio.create_task(worker(i, req)))
        await asyncio.gather(*tasks)

    return [r for r in results if r is not None]
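
# How the open-loop mode behaves (example numbers, not defaults): with
# --request-rate 0.5, inter-arrival gaps are drawn from random.expovariate(0.5),
# i.e. 2 s apart on average, regardless of how quickly the GPU drains the queue;
# the semaphore still caps in-flight work at --concurrency, so excess arrivals
# simply accumulate queue_wait_ms.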


def summarize(results: List[RequestResult], wall_s: float) -> RunSummary:
    lat = [r.latency_ms for r in results if r.ok]
    ttfs = [r.ttfs_ms for r in results if r.ok and math.isfinite(r.ttfs_ms)]
    all_steps = [s for r in results if r.ok for s in r.step_latencies_ms]
    succ = sum(1 for r in results if r.ok)
    fail = len(results) - succ
    return RunSummary(
        total_requests=len(results),
        success=succ,
        failed=fail,
        throughput_req_s=(succ / wall_s) if wall_s > 0 else 0.0,
        p50_ms=percentile(lat, 50),
        p90_ms=percentile(lat, 90),
        p95_ms=percentile(lat, 95),
        p99_ms=percentile(lat, 99),
        mean_ms=(statistics.mean(lat) if lat else float("nan")),
        ttfs_p50_ms=percentile(ttfs, 50),
        ttfs_p99_ms=percentile(ttfs, 99),
        step_mean_ms=(statistics.mean(all_steps) if all_steps else float("nan")),
        step_p99_ms=percentile(all_steps, 99),
        max_vram_mb=max((r.peak_vram_mb for r in results), default=0.0),
    )


def print_summary(
    args: argparse.Namespace,
    summ: RunSummary,
    total_requests: int,
    wall_s: float,
) -> None:
    w = 60
    sep = "-" * w
    print("\n" + "=" * w)
    print("{s:^{n}}".format(s=" WAN Benchmark Result ", n=w))
    print("=" * w)
    print("{:<40} {:<}".format("Checkpoint:", Path(args.checkpoint).name))
    print("{:<40} {:<}".format("Device / dtype:", f"{args.device}/{args.dtype}"))
    print("{:<40} {:<}".format("Concurrency:", args.concurrency))
    rate_str = f"{args.request_rate:.1f} req/s" if args.request_rate != float("inf") else "inf (closed-loop)"
    print("{:<40} {:<}".format("Request rate:", rate_str))
    print(sep)
    print("{:<40} {:<.2f}".format("Benchmark duration (s):", wall_s))
    print("{:<40} {}/{}".format("Successful requests:", summ.success, total_requests))
    if summ.failed:
        print("{:<40} {:<}".format("Failed requests:", summ.failed))
    print(sep)
    print("{:<40} {:<.3f}".format("Throughput (req/s):", summ.throughput_req_s))
    print("{:<40} {:<.1f}".format("Latency mean (ms):", summ.mean_ms))
    print("{:<40} {:<.1f}".format("Latency p50 (ms):", summ.p50_ms))
    print("{:<40} {:<.1f}".format("Latency p90 (ms):", summ.p90_ms))
    print("{:<40} {:<.1f}".format("Latency p95 (ms):", summ.p95_ms))
    print("{:<40} {:<.1f}".format("Latency p99 (ms):", summ.p99_ms))
    print(sep)
    print("{:<40} {:<.1f}".format("TTFS p50 (ms):", summ.ttfs_p50_ms))
    print("{:<40} {:<.1f}".format("TTFS p99 (ms):", summ.ttfs_p99_ms))
    print("{:<40} {:<.1f}".format("Step latency mean (ms):", summ.step_mean_ms))
    print("{:<40} {:<.1f}".format("Step latency p99 (ms):", summ.step_p99_ms))
    print(sep)
    print("{:<40} {:<.1f}".format("Peak VRAM (MB):", summ.max_vram_mb))
    print("=" * w)


# -----------------------------
# CLI
# -----------------------------

def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Benchmark ComfyUI WAN diffusion model denoising throughput and latency."
    )
    p.add_argument(
        "--checkpoint", required=True,
        help="Path to the WAN diffusion-model checkpoint (.safetensors / .pt).",
    )
    p.add_argument("--device", default="cuda")
    p.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp32"])
    p.add_argument("--num-requests", type=int, default=100)
    p.add_argument("--concurrency", type=int, default=4,
                   help="Max number of in-flight requests (semaphore width).")
    p.add_argument(
        "--request-rate", type=float, default=float("inf"),
        help="Poisson arrival rate in req/s. inf = closed-loop (default).",
    )
    p.add_argument("--warmup-requests", type=int, default=2,
                   help="Warmup iterations excluded from metrics.")
    p.add_argument("--seed", type=int, default=1234)
    p.add_argument("--text-seq-len", type=int, default=512,
                   help="Cross-attention sequence length (UMT5 default: 512).")
    p.add_argument("--text-dim", type=int, default=4096,
                   help="Text embedding width (UMT5-XXL: 4096).")
    p.add_argument("--out-dir", type=Path, default=Path("benchmarks/out"))
    p.add_argument("--output-file", type=Path, default=None,
                   help="Override path for the summary JSON output.")
    return p.parse_args()


def default_profiles() -> List[RequestSpec]:
    return [
        RequestSpec("wan21_t2v_720p_16f_30s", 1, 1280, 720, 16, 30, 6.0, 0),
        RequestSpec("wan21_t2v_720p_32f_30s", 1, 1280, 720, 32, 30, 6.0, 0),
        RequestSpec("wan21_t2v_480p_32f_20s", 1, 854, 480, 32, 20, 6.0, 0),
    ]
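
# Profiles are positional RequestSpec values in field order: profile_name,
# batch_size, width, height, num_frames, steps, cfg_scale, seed. A hypothetical
# extra entry for a lighter 480p/16-frame workload would look like:
#   RequestSpec("wan21_t2v_480p_16f_20s", 1, 854, 480, 16, 20, 6.0, 0)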


async def main_async() -> None:
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    runner = WanRunner(
        checkpoint=args.checkpoint,
        device=args.device,
        dtype_str=args.dtype,
        text_seq_len=args.text_seq_len,
        text_dim=args.text_dim,
    )

    all_reqs = build_request_stream(
        args.num_requests + args.warmup_requests,
        args.seed,
        default_profiles(),
    )
    warmup_reqs = all_reqs[: args.warmup_requests]
    bench_reqs = all_reqs[args.warmup_requests :]

    if warmup_reqs:
        print(f"Running {len(warmup_reqs)} warmup request(s)...")
        for req in warmup_reqs:
            runner.run_one(req)
        print("Warmup complete.")

    print(f"Benchmarking {len(bench_reqs)} requests (concurrency={args.concurrency})...")
    t0 = now()
    results = await run_closed_loop(runner, bench_reqs, args.concurrency, args.request_rate)
    wall_s = now() - t0

    summ = summarize(results, wall_s)
    print_summary(args, summ, len(bench_reqs), wall_s)

    out_file = args.output_file or (args.out_dir / "summary.json")
    with open(args.out_dir / "requests.jsonl", "w") as f:
        for r in results:
            f.write(json.dumps(dataclasses.asdict(r)) + "\n")
    with open(out_file, "w") as f:
        json.dump(dataclasses.asdict(summ), f, indent=2)
    print(f"\nResults written to {args.out_dir}/")


if __name__ == "__main__":
    asyncio.run(main_async())
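
# Example invocation (hypothetical checkpoint path):
#   python benchmarks/benchmarking_model_serving.py \
#       --checkpoint models/diffusion_models/wan2.1_t2v_1.3B_fp16.safetensors \
#       --num-requests 50 --concurrency 2 --request-rate 0.25 --dtype fp16
# Per-request records go to benchmarks/out/requests.jsonl and the aggregate
# summary to benchmarks/out/summary.json (override with --out-dir / --output-file).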