diff --git a/benchmarks/benchmarking_model_serving.py b/benchmarks/benchmarking_model_serving.py
new file mode 100644
index 000000000..493d0a574
--- /dev/null
+++ b/benchmarks/benchmarking_model_serving.py
@@ -0,0 +1,444 @@
+from __future__ import annotations
+
+import argparse
+import asyncio
+import dataclasses
+import json
+import math
+import os
+import random
+import statistics
+import sys
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+
+# Make the repo root importable when running directly from the benchmarks/ dir.
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import comfy.model_management
+import comfy.sd
+
+
+# -----------------------------
+# Data models
+# -----------------------------
+
+@dataclasses.dataclass
+class RequestSpec:
+    profile_name: str
+    batch_size: int
+    width: int
+    height: int
+    num_frames: int  # latent frames: the T dimension of the latent tensor
+    steps: int
+    cfg_scale: float  # recorded per profile; CFG is not simulated in run_one
+    seed: int
+    timeout_s: float = 180.0
+    extra: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+
+@dataclasses.dataclass
+class RequestResult:
+    request_id: int
+    profile_name: str
+    ok: bool
+    error: Optional[str]
+    latency_ms: float
+    queue_wait_ms: float
+    step_latencies_ms: List[float]
+    ttfs_ms: float  # time to first (denoising) step
+    peak_vram_mb: float
+    est_mem_mb: Optional[float]
+    started_at: float
+    ended_at: float
+
+
+@dataclasses.dataclass
+class RunSummary:
+    total_requests: int
+    success: int
+    failed: int
+    throughput_req_s: float
+    p50_ms: float
+    p90_ms: float
+    p95_ms: float
+    p99_ms: float
+    mean_ms: float
+    ttfs_p50_ms: float
+    ttfs_p99_ms: float
+    step_mean_ms: float
+    step_p99_ms: float
+    max_vram_mb: float
+
+
+# -----------------------------
+# Helpers
+# -----------------------------
+
+def percentile(values: List[float], p: float) -> float:
+    """Percentile with linear interpolation between closest ranks."""
+    if not values:
+        return float("nan")
+    values = sorted(values)
+    k = (len(values) - 1) * (p / 100.0)
+    f = math.floor(k)
+    c = math.ceil(k)
+    if f == c:
+        return values[int(k)]
+    return values[f] * (c - k) + values[c] * (k - f)
+
+
+def now() -> float:
+    return time.perf_counter()
+
+
+def gpu_peak_mb() -> float:
+    if not torch.cuda.is_available():
+        return 0.0
+    return torch.cuda.max_memory_allocated() / (1024 ** 2)
+
+
+def reset_gpu_peak() -> None:
+    if torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
+
+
+def sync_cuda() -> None:
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+
+def build_request_stream(
+    num_requests: int,
+    base_seed: int,
+    profiles: List[RequestSpec],
+    weighted: Optional[List[float]] = None,
+) -> List[RequestSpec]:
+    rnd = random.Random(base_seed)
+    out: List[RequestSpec] = []
+    for i in range(num_requests):
+        p = rnd.choices(profiles, weights=weighted, k=1)[0]
+        out.append(dataclasses.replace(p, seed=base_seed + i))
+    return out
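+
+# Illustrative usage (not executed here): the optional `weighted` argument
+# skews the request mix, e.g. 70/30 between the first two default profiles:
+#
+#     stream = build_request_stream(
+#         num_requests=100,
+#         base_seed=1234,
+#         profiles=default_profiles()[:2],
+#         weighted=[0.7, 0.3],
+#     )
+#
+# Every request receives a distinct seed (base_seed + i), so a given
+# (seed, profile list) pair always reproduces the same stream.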
+
+
+# -----------------------------
+# Model adapter
+# -----------------------------
+
+class WanRunner:
+    """
+    Thin adapter around ComfyUI model loading + the BaseModel.apply_model call path.
+
+    Only the DiT denoiser is timed: no VAE encode/decode, no CLIP, and no
+    scheduler overhead, so measurements reflect pure denoiser inference cost.
+
+    Latent shape convention (WAN): [B, 16, T, H//8, W//8]
+    Text conditioning shape (UMT5): [B, text_seq_len, text_dim] (zeros for benchmarking)
+    Sigma schedule (flow-matching): linspace(1.0, 1.0/steps, steps)
+    """
+
+    def __init__(
+        self,
+        checkpoint: str,
+        device: str,
+        dtype_str: str,
+        text_seq_len: int = 512,
+        text_dim: int = 4096,
+    ):
+        self.checkpoint = checkpoint
+        # Recorded for reporting only; the compute device is chosen by
+        # comfy.model_management in run_one().
+        self.device_str = device
+        self.dtype_str = dtype_str
+        self.text_seq_len = text_seq_len
+        self.text_dim = text_dim
+        self.patcher, self.model = self._load_model()
+
+    # ------------------------------------------------------------------
+    # Internals
+    # ------------------------------------------------------------------
+
+    def _load_model(self):
+        dtype_map = {
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16,
+            "fp32": torch.float32,
+        }
+        dtype = dtype_map.get(self.dtype_str)
+        model_opts = {"dtype": dtype} if dtype is not None else {}
+
+        patcher = comfy.sd.load_diffusion_model(self.checkpoint, model_options=model_opts)
+        # force_full_load=True keeps the whole model resident on GPU rather than
+        # streaming weights on demand (important for latency benchmarking).
+        comfy.model_management.load_models_gpu([patcher], force_full_load=True)
+        return patcher, patcher.model
+
+    def _estimate_mem_mb(self, latent_shape: tuple, text_seq_len: int) -> Optional[float]:
+        cond_shapes = {
+            "c_crossattn": [(latent_shape[0], text_seq_len, self.text_dim)],
+        }
+        try:
+            return self.model.memory_required(latent_shape, cond_shapes) / (1024 ** 2)
+        except Exception:
+            # memory_required's signature varies across ComfyUI versions;
+            # treat the estimate as best-effort.
+            return None
+
+    # ------------------------------------------------------------------
+    # Single-request execution
+    # ------------------------------------------------------------------
+
+    @torch.inference_mode()
+    def run_one(self, req: RequestSpec) -> RequestResult:
+        start = now()
+        # Peak-VRAM tracking is global per device: with concurrency > 1,
+        # overlapping requests share the counter, so per-request peaks are
+        # only exact at concurrency = 1.
+        reset_gpu_peak()
+
+        step_latencies: List[float] = []
+        ttfs_ms = float("nan")
+        est_mem_mb: Optional[float] = None
+        ok = True
+        err = None
+
+        try:
+            device = comfy.model_management.get_torch_device()
+            dtype = self.model.get_dtype_inference()
+
+            # Latent noise tensor: [B, 16 channels, T frames, H/8, W/8]
+            latent_shape = (
+                req.batch_size, 16,
+                req.num_frames,
+                req.height // 8,
+                req.width // 8,
+            )
+            x = torch.randn(latent_shape, dtype=dtype, device=device)
+            est_mem_mb = self._estimate_mem_mb(latent_shape, self.text_seq_len)
+
+            # Fake text conditioning: zeros have the right shape, and real
+            # values are not needed for throughput/latency benchmarking.
+            cross_attn = torch.zeros(
+                req.batch_size, self.text_seq_len, self.text_dim,
+                dtype=dtype, device=device,
+            )
+
+            # Linear sigma schedule: 1.0 → 1/steps (flow-matching, noise → clean)
+            sigmas = torch.linspace(1.0, 1.0 / req.steps, req.steps, device=device)
+
+            for step_i, sigma_val in enumerate(sigmas):
+                sigma_t = sigma_val.expand(req.batch_size)
+                t0 = now()
+                # The output is fed straight back in as the next step's input.
+                # This is not a real sampler update, but each call costs the
+                # same compute as one true denoising step.
+                x = self.model.apply_model(x, sigma_t, c_crossattn=cross_attn)
+                sync_cuda()
+                elapsed_ms = (now() - t0) * 1000.0
+                step_latencies.append(elapsed_ms)
+                if step_i == 0:
+                    ttfs_ms = elapsed_ms
+
+        except Exception as e:
+            ok = False
+            err = repr(e)
+
+        end = now()
+        return RequestResult(
+            request_id=-1,
+            profile_name=req.profile_name,
+            ok=ok,
+            error=err,
+            latency_ms=(end - start) * 1000.0,
+            queue_wait_ms=0.0,  # filled in by the scheduler
+            step_latencies_ms=step_latencies,
+            ttfs_ms=ttfs_ms,
+            peak_vram_mb=gpu_peak_mb(),
+            est_mem_mb=est_mem_mb,
+            started_at=start,
+            ended_at=end,
+        )
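+
+# Shape sanity check (illustrative): with the default UMT5 settings, a
+# batch-1, 1280x720, 16-frame request produces
+#     latent x:   (1, 16, 16, 90, 160)   # 720 // 8 = 90, 1280 // 8 = 160
+#     cross_attn: (1, 512, 4096)
+# Note that num_frames counts latent frames (the tensor's T dimension);
+# no VAE is involved here, so it does not map to decoded video frames.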
+
+
+# -----------------------------
+# Serving-style scheduler
+# -----------------------------
+
+async def run_closed_loop(
+    runner: WanRunner,
+    requests: List[RequestSpec],
+    concurrency: int,
+    request_rate: float = float("inf"),
+) -> List[RequestResult]:
+    """
+    Closed-loop scheduler (default) or Poisson open-loop when request_rate is finite.
+
+    Each request is dispatched to a thread so the asyncio event loop stays
+    free to issue the next request while the GPU is busy.
+    """
+    sem = asyncio.Semaphore(concurrency)
+    results: List[Optional[RequestResult]] = [None] * len(requests)
+
+    async def worker(i: int, req: RequestSpec) -> None:
+        # Enqueue time is taken before acquiring the semaphore so that
+        # queue_wait_ms includes time spent waiting for a free slot.
+        t_enq = now()
+        async with sem:
+            res = await asyncio.to_thread(runner.run_one, req)
+            res.request_id = i
+            res.queue_wait_ms = max(0.0, (res.started_at - t_enq) * 1000.0)
+            results[i] = res
+
+    if request_rate == float("inf") or request_rate <= 0:
+        await asyncio.gather(*(worker(i, r) for i, r in enumerate(requests)))
+    else:
+        tasks: List[asyncio.Task] = []
+        for i, req in enumerate(requests):
+            if i > 0:
+                await asyncio.sleep(random.expovariate(request_rate))
+            tasks.append(asyncio.create_task(worker(i, req)))
+        await asyncio.gather(*tasks)
+
+    return [r for r in results if r is not None]
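+
+# Open-loop pacing (illustrative): random.expovariate(rate) draws
+# exponentially distributed gaps with mean 1/rate, so --request-rate 2.0
+# issues on average one request every 0.5 s, independent of how long each
+# request takes. The closed-loop default instead lets the semaphore keep
+# exactly `concurrency` requests in flight at all times.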
print("{:<40} {:<}".format("Concurrency:", args.concurrency)) + rate_str = f"{args.request_rate:.1f} req/s" if args.request_rate != float("inf") else "inf (closed-loop)" + print("{:<40} {:<}".format("Request rate:", rate_str)) + print(sep) + print("{:<40} {:<.2f}".format("Benchmark duration (s):", wall_s)) + print("{:<40} {}/{}".format("Successful requests:", summ.success, total_requests)) + if summ.failed: + print("{:<40} {:<}".format("Failed requests:", summ.failed)) + print(sep) + print("{:<40} {:<.3f}".format("Throughput (req/s):", summ.throughput_req_s)) + print("{:<40} {:<.1f}".format("Latency mean (ms):", summ.mean_ms)) + print("{:<40} {:<.1f}".format("Latency p50 (ms):", summ.p50_ms)) + print("{:<40} {:<.1f}".format("Latency p90 (ms):", summ.p90_ms)) + print("{:<40} {:<.1f}".format("Latency p95 (ms):", summ.p95_ms)) + print("{:<40} {:<.1f}".format("Latency p99 (ms):", summ.p99_ms)) + print(sep) + print("{:<40} {:<.1f}".format("TTFS p50 (ms):", summ.ttfs_p50_ms)) + print("{:<40} {:<.1f}".format("TTFS p99 (ms):", summ.ttfs_p99_ms)) + print("{:<40} {:<.1f}".format("Step latency mean (ms):", summ.step_mean_ms)) + print("{:<40} {:<.1f}".format("Step latency p99 (ms):", summ.step_p99_ms)) + print(sep) + print("{:<40} {:<.1f}".format("Peak VRAM (MB):", summ.max_vram_mb)) + print("=" * w) + + +# ----------------------------- +# CLI +# ----------------------------- + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Benchmark ComfyUI WAN diffusion model denoising throughput and latency." + ) + p.add_argument( + "--checkpoint", required=True, + help="Path to the WAN diffusion-model checkpoint (.safetensors / .pt).", + ) + p.add_argument("--device", default="cuda") + p.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp32"]) + p.add_argument("--num-requests", type=int, default=100) + p.add_argument("--concurrency", type=int, default=4, + help="Max number of in-flight requests (semaphore width).") + p.add_argument( + "--request-rate", type=float, default=float("inf"), + help="Poisson arrival rate in req/s. 
+
+
+def default_profiles() -> List[RequestSpec]:
+    # Columns: name, batch, width, height, latent_frames, steps, cfg, seed.
+    return [
+        RequestSpec("wan21_t2v_720p_16f_30s", 1, 1280, 720, 16, 30, 6.0, 0),
+        RequestSpec("wan21_t2v_720p_32f_30s", 1, 1280, 720, 32, 30, 6.0, 0),
+        # 832x480 rather than 854x480: both dimensions must be divisible by 8
+        # for the latent grid.
+        RequestSpec("wan21_t2v_480p_32f_20s", 1, 832, 480, 32, 20, 6.0, 0),
+    ]
+
+
+async def main_async() -> None:
+    args = parse_args()
+    args.out_dir.mkdir(parents=True, exist_ok=True)
+
+    runner = WanRunner(
+        checkpoint=args.checkpoint,
+        device=args.device,
+        dtype_str=args.dtype,
+        text_seq_len=args.text_seq_len,
+        text_dim=args.text_dim,
+    )
+
+    all_reqs = build_request_stream(
+        args.num_requests + args.warmup_requests,
+        args.seed,
+        default_profiles(),
+    )
+    warmup_reqs = all_reqs[: args.warmup_requests]
+    bench_reqs = all_reqs[args.warmup_requests :]
+
+    if warmup_reqs:
+        print(f"Running {len(warmup_reqs)} warmup request(s)...")
+        for req in warmup_reqs:
+            runner.run_one(req)
+        print("Warmup complete.")
+
+    print(f"Benchmarking {len(bench_reqs)} requests (concurrency={args.concurrency})...")
+    t0 = now()
+    results = await run_closed_loop(runner, bench_reqs, args.concurrency, args.request_rate)
+    wall_s = now() - t0
+
+    summ = summarize(results, wall_s)
+    print_summary(args, summ, len(bench_reqs), wall_s)
+
+    out_file = args.output_file or (args.out_dir / "summary.json")
+    with open(args.out_dir / "requests.jsonl", "w") as f:
+        for r in results:
+            f.write(json.dumps(dataclasses.asdict(r)) + "\n")
+    with open(out_file, "w") as f:
+        json.dump(dataclasses.asdict(summ), f, indent=2)
+    print(f"\nPer-request results: {args.out_dir / 'requests.jsonl'}")
+    print(f"Summary: {out_file}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main_async())
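+
+# Output layout (for reference): each run writes two artifacts under --out-dir:
+#   requests.jsonl : one RequestResult per line (per-request latency, queue
+#                    wait, step timings, TTFS, peak VRAM, any error repr)
+#   summary.json   : the aggregated RunSummary printed above (or the
+#                    --output-file override, if given)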