Add benchmark for model serving

Tara Ding 2026-04-24 16:24:28 -07:00
parent 7636599389
commit ac85d7887f


@@ -0,0 +1,444 @@
from __future__ import annotations
import argparse
import asyncio
import dataclasses
import json
import math
import os
import random
import statistics
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
# Make the repo root importable when running directly from the benchmarks/ dir.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import comfy.model_management
import comfy.sd
# -----------------------------
# Data models
# -----------------------------
@dataclasses.dataclass
class RequestSpec:
profile_name: str
batch_size: int
width: int
height: int
num_frames: int
steps: int
cfg_scale: float
seed: int
timeout_s: float = 180.0
extra: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class RequestResult:
request_id: int
profile_name: str
ok: bool
error: Optional[str]
latency_ms: float
queue_wait_ms: float
step_latencies_ms: List[float]
ttfs_ms: float # time to first (denoising) step
peak_vram_mb: float
est_mem_mb: Optional[float]
started_at: float
ended_at: float
@dataclasses.dataclass
class RunSummary:
total_requests: int
success: int
failed: int
throughput_req_s: float
p50_ms: float
p90_ms: float
p95_ms: float
p99_ms: float
mean_ms: float
ttfs_p50_ms: float
ttfs_p99_ms: float
step_mean_ms: float
step_p99_ms: float
max_vram_mb: float
# -----------------------------
# Helpers
# -----------------------------
def percentile(values: List[float], p: float) -> float:
if not values:
return float("nan")
values = sorted(values)
k = (len(values) - 1) * (p / 100.0)
f = math.floor(k)
c = math.ceil(k)
if f == c:
return values[int(k)]
return values[f] * (c - k) + values[c] * (k - f)
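# Worked example of the interpolation above: percentile([10, 20, 30, 40], 90)
# gives k = 2.7, f = 2, c = 3, so the result is 30 * 0.3 + 40 * 0.7 = 37.0.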
def now() -> float:
return time.perf_counter()
def gpu_peak_mb() -> float:
if not torch.cuda.is_available():
return 0.0
return torch.cuda.max_memory_allocated() / (1024 ** 2)
def reset_gpu_peak() -> None:
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def sync_cuda() -> None:
if torch.cuda.is_available():
torch.cuda.synchronize()
def build_request_stream(
num_requests: int,
base_seed: int,
profiles: List[RequestSpec],
weighted: Optional[List[float]] = None,
) -> List[RequestSpec]:
rnd = random.Random(base_seed)
out: List[RequestSpec] = []
for i in range(num_requests):
p = rnd.choices(profiles, weights=weighted, k=1)[0]
out.append(dataclasses.replace(p, seed=base_seed + i))
return out
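# Example: build_request_stream(3, 0, default_profiles()) draws three specs
# uniformly from the profiles (weighted=None) and reseeds them as 0, 1, 2,
# so repeated runs replay an identical request mix.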
# -----------------------------
# Model adapter
# -----------------------------
class WanRunner:
"""
Thin adapter around ComfyUI model loading + the BaseModel.apply_model call path.
    Only the DiT denoiser is timed (no VAE encode/decode, no CLIP, no scheduler
    overhead), so measurements reflect true model inference cost.
Latent shape convention (WAN): [B, 16, T, H//8, W//8]
Text conditioning shape (UMT5): [B, text_seq_len, text_dim] (zeros for benchmarking)
    Sigma schedule (flow-matching): linspace(1.0 → 1/steps, steps)
"""
def __init__(
self,
checkpoint: str,
device: str,
dtype_str: str,
text_seq_len: int = 512,
text_dim: int = 4096,
):
self.checkpoint = checkpoint
self.device_str = device
self.dtype_str = dtype_str
self.text_seq_len = text_seq_len
self.text_dim = text_dim
self.patcher, self.model = self._load_model()
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _load_model(self):
dtype_map = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
dtype = dtype_map.get(self.dtype_str)
model_opts = {"dtype": dtype} if dtype is not None else {}
patcher = comfy.sd.load_diffusion_model(self.checkpoint, model_options=model_opts)
# force_full_load=True keeps the whole model resident on GPU rather than
# streaming weights on demand (important for latency benchmarking).
comfy.model_management.load_models_gpu([patcher], force_full_load=True)
return patcher, patcher.model
def _estimate_mem_mb(self, latent_shape: tuple, text_seq_len: int) -> Optional[float]:
cond_shapes = {
"c_crossattn": [(latent_shape[0], text_seq_len, self.text_dim)],
}
try:
return self.model.memory_required(latent_shape, cond_shapes) / (1024 ** 2)
except Exception:
return None
# ------------------------------------------------------------------
# Single-request execution
# ------------------------------------------------------------------
@torch.inference_mode()
def run_one(self, req: RequestSpec) -> RequestResult:
start = now()
reset_gpu_peak()
step_latencies: List[float] = []
ttfs_ms = float("nan")
est_mem_mb: Optional[float] = None
ok = True
err = None
try:
device = comfy.model_management.get_torch_device()
dtype = self.model.get_dtype_inference()
# Latent noise tensor: [B, 16 channels, T frames, H/8, W/8]
latent_shape = (
req.batch_size, 16,
req.num_frames,
req.height // 8,
req.width // 8,
)
x = torch.randn(latent_shape, dtype=dtype, device=device)
est_mem_mb = self._estimate_mem_mb(latent_shape, self.text_seq_len)
# Fake text conditioning — zeros have the right shape, non-zero
# values are not needed for throughput/latency benchmarking.
cross_attn = torch.zeros(
req.batch_size, self.text_seq_len, self.text_dim,
dtype=dtype, device=device,
)
# Linear sigma schedule: 1.0 → 1/steps (flow-matching, noise→clean)
sigmas = torch.linspace(1.0, 1.0 / req.steps, req.steps, device=device)
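            # e.g. steps=4 yields sigmas = [1.00, 0.75, 0.50, 0.25]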
for step_i, sigma_val in enumerate(sigmas):
sigma_t = sigma_val.expand(req.batch_size)
t0 = now()
x = self.model.apply_model(x, sigma_t, c_crossattn=cross_attn)
sync_cuda()
elapsed_ms = (now() - t0) * 1000.0
step_latencies.append(elapsed_ms)
if step_i == 0:
ttfs_ms = elapsed_ms
except Exception as e:
ok = False
err = repr(e)
end = now()
return RequestResult(
request_id=-1,
profile_name=req.profile_name,
ok=ok,
error=err,
latency_ms=(end - start) * 1000.0,
queue_wait_ms=0.0, # filled in by the scheduler
step_latencies_ms=step_latencies,
ttfs_ms=ttfs_ms,
peak_vram_mb=gpu_peak_mb(),
est_mem_mb=est_mem_mb,
started_at=start,
ended_at=end,
)
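# Standalone usage sketch (the checkpoint path is illustrative, not a file
# shipped with the repo):
#   runner = WanRunner("models/diffusion_models/wan2.1_t2v_1.3B_fp16.safetensors",
#                      device="cuda", dtype_str="fp16")
#   res = runner.run_one(default_profiles()[0])
#   print(res.latency_ms, res.peak_vram_mb)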
# -----------------------------
# Serving-style scheduler
# -----------------------------
async def run_closed_loop(
runner: WanRunner,
requests: List[RequestSpec],
concurrency: int,
request_rate: float = float("inf"),
) -> List[RequestResult]:
"""
Closed-loop scheduler (default) or Poisson open-loop when request_rate is finite.
Each request is dispatched to a thread so the asyncio event loop stays
free to issue the next request while the GPU is busy.
"""
sem = asyncio.Semaphore(concurrency)
results: List[Optional[RequestResult]] = [None] * len(requests)
async def worker(i: int, req: RequestSpec) -> None:
async with sem:
t_enq = now()
res = await asyncio.to_thread(runner.run_one, req)
res.request_id = i
res.queue_wait_ms = max(0.0, (res.started_at - t_enq) * 1000.0)
results[i] = res
if request_rate == float("inf") or request_rate <= 0:
await asyncio.gather(*(worker(i, r) for i, r in enumerate(requests)))
else:
tasks: List[asyncio.Task] = []
for i, req in enumerate(requests):
if i > 0:
await asyncio.sleep(random.expovariate(request_rate))
tasks.append(asyncio.create_task(worker(i, req)))
await asyncio.gather(*tasks)
return [r for r in results if r is not None]
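# Arrival-pattern note: with --request-rate 2.0 the gaps between dispatches are
# drawn from expovariate(2.0), i.e. exponentially distributed with mean 0.5 s
# (an approximately Poisson open-loop client); the default of inf issues the
# next request as soon as the concurrency semaphore allows (closed loop).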
def summarize(results: List[RequestResult], wall_s: float) -> RunSummary:
lat = [r.latency_ms for r in results if r.ok]
ttfs = [r.ttfs_ms for r in results if r.ok and math.isfinite(r.ttfs_ms)]
all_steps = [s for r in results if r.ok for s in r.step_latencies_ms]
succ = sum(1 for r in results if r.ok)
fail = len(results) - succ
return RunSummary(
total_requests=len(results),
success=succ,
failed=fail,
throughput_req_s=(succ / wall_s) if wall_s > 0 else 0.0,
p50_ms=percentile(lat, 50),
p90_ms=percentile(lat, 90),
p95_ms=percentile(lat, 95),
p99_ms=percentile(lat, 99),
mean_ms=(statistics.mean(lat) if lat else float("nan")),
ttfs_p50_ms=percentile(ttfs, 50),
ttfs_p99_ms=percentile(ttfs, 99),
step_mean_ms=(statistics.mean(all_steps) if all_steps else float("nan")),
step_p99_ms=percentile(all_steps, 99),
max_vram_mb=max((r.peak_vram_mb for r in results), default=0.0),
)
def print_summary(
args: argparse.Namespace,
summ: RunSummary,
total_requests: int,
wall_s: float,
) -> None:
w = 60
sep = "-" * w
print("\n" + "=" * w)
print("{s:^{n}}".format(s=" WAN Benchmark Result ", n=w))
print("=" * w)
print("{:<40} {:<}".format("Checkpoint:", Path(args.checkpoint).name))
print("{:<40} {:<}".format("Device / dtype:", f"{args.device}/{args.dtype}"))
print("{:<40} {:<}".format("Concurrency:", args.concurrency))
rate_str = f"{args.request_rate:.1f} req/s" if args.request_rate != float("inf") else "inf (closed-loop)"
print("{:<40} {:<}".format("Request rate:", rate_str))
print(sep)
print("{:<40} {:<.2f}".format("Benchmark duration (s):", wall_s))
print("{:<40} {}/{}".format("Successful requests:", summ.success, total_requests))
if summ.failed:
print("{:<40} {:<}".format("Failed requests:", summ.failed))
print(sep)
print("{:<40} {:<.3f}".format("Throughput (req/s):", summ.throughput_req_s))
print("{:<40} {:<.1f}".format("Latency mean (ms):", summ.mean_ms))
print("{:<40} {:<.1f}".format("Latency p50 (ms):", summ.p50_ms))
print("{:<40} {:<.1f}".format("Latency p90 (ms):", summ.p90_ms))
print("{:<40} {:<.1f}".format("Latency p95 (ms):", summ.p95_ms))
print("{:<40} {:<.1f}".format("Latency p99 (ms):", summ.p99_ms))
print(sep)
print("{:<40} {:<.1f}".format("TTFS p50 (ms):", summ.ttfs_p50_ms))
print("{:<40} {:<.1f}".format("TTFS p99 (ms):", summ.ttfs_p99_ms))
print("{:<40} {:<.1f}".format("Step latency mean (ms):", summ.step_mean_ms))
print("{:<40} {:<.1f}".format("Step latency p99 (ms):", summ.step_p99_ms))
print(sep)
print("{:<40} {:<.1f}".format("Peak VRAM (MB):", summ.max_vram_mb))
print("=" * w)
# -----------------------------
# CLI
# -----------------------------
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Benchmark ComfyUI WAN diffusion model denoising throughput and latency."
)
p.add_argument(
"--checkpoint", required=True,
help="Path to the WAN diffusion-model checkpoint (.safetensors / .pt).",
)
p.add_argument("--device", default="cuda")
p.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp32"])
p.add_argument("--num-requests", type=int, default=100)
p.add_argument("--concurrency", type=int, default=4,
help="Max number of in-flight requests (semaphore width).")
p.add_argument(
"--request-rate", type=float, default=float("inf"),
help="Poisson arrival rate in req/s. inf = closed-loop (default).",
)
p.add_argument("--warmup-requests", type=int, default=2,
help="Warmup iterations excluded from metrics.")
p.add_argument("--seed", type=int, default=1234)
p.add_argument("--text-seq-len", type=int, default=512,
help="Cross-attention sequence length (UMT5 default: 512).")
p.add_argument("--text-dim", type=int, default=4096,
help="Text embedding width (UMT5-XXL: 4096).")
p.add_argument("--out-dir", type=Path, default=Path("benchmarks/out"))
p.add_argument("--output-file", type=Path, default=None,
help="Override path for the summary JSON output.")
return p.parse_args()
def default_profiles() -> List[RequestSpec]:
return [
RequestSpec("wan21_t2v_720p_16f_30s", 1, 1280, 720, 16, 30, 6.0, 0),
RequestSpec("wan21_t2v_720p_32f_30s", 1, 1280, 720, 32, 30, 6.0, 0),
RequestSpec("wan21_t2v_480p_32f_20s", 1, 854, 480, 32, 20, 6.0, 0),
]
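# Example invocation (script filename and checkpoint path are illustrative):
#   python benchmarks/benchmark_wan_serving.py \
#       --checkpoint models/diffusion_models/wan2.1_t2v_1.3B_fp16.safetensors \
#       --num-requests 50 --concurrency 2 --request-rate 0.5 --dtype bf16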
async def main_async() -> None:
args = parse_args()
args.out_dir.mkdir(parents=True, exist_ok=True)
runner = WanRunner(
checkpoint=args.checkpoint,
device=args.device,
dtype_str=args.dtype,
text_seq_len=args.text_seq_len,
text_dim=args.text_dim,
)
all_reqs = build_request_stream(
args.num_requests + args.warmup_requests,
args.seed,
default_profiles(),
)
warmup_reqs = all_reqs[: args.warmup_requests]
bench_reqs = all_reqs[args.warmup_requests :]
if warmup_reqs:
print(f"Running {len(warmup_reqs)} warmup request(s)...")
for req in warmup_reqs:
runner.run_one(req)
print("Warmup complete.")
print(f"Benchmarking {len(bench_reqs)} requests (concurrency={args.concurrency})...")
t0 = now()
results = await run_closed_loop(runner, bench_reqs, args.concurrency, args.request_rate)
wall_s = now() - t0
summ = summarize(results, wall_s)
print_summary(args, summ, len(bench_reqs), wall_s)
out_file = args.output_file or (args.out_dir / "summary.json")
with open(args.out_dir / "requests.jsonl", "w") as f:
for r in results:
f.write(json.dumps(dataclasses.asdict(r)) + "\n")
with open(out_file, "w") as f:
json.dump(dataclasses.asdict(summ), f, indent=2)
print(f"\nResults written to {args.out_dir}/")
if __name__ == "__main__":
asyncio.run(main_async())