mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-02 21:32:31 +08:00
445 lines
14 KiB
Python
445 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import dataclasses
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import statistics
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
|
|
# Make the repo root importable when running directly from the benchmarks/ dir.
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import comfy.model_management
|
|
import comfy.sd
|
|
|
|
|
|
# -----------------------------
|
|
# Data models
|
|
# -----------------------------
|
|
|
|
@dataclasses.dataclass
class RequestSpec:
    """Parameters describing one synthetic benchmark request."""

    # Label of the workload profile; copied into the matching RequestResult.
    profile_name: str
    # Leading (batch) dimension of the latent and text-conditioning tensors.
    batch_size: int
    # Pixel resolution; the runner floor-divides both by 8 to size the latent.
    width: int
    height: int
    # Size of the latent's temporal dimension.
    num_frames: int
    # Number of denoiser calls issued for this request.
    steps: int
    # Carried with the request; not read by the code visible in this file.
    cfg_scale: float
    # build_request_stream overwrites this with base_seed + request index.
    seed: int
    # Carried with the request; not read by the code visible in this file.
    timeout_s: float = 180.0
    # Free-form per-request options; every instance gets its own dict.
    extra: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
|
|
|
|
|
@dataclasses.dataclass
class RequestResult:
    """Outcome and measurements recorded for one executed request."""

    # The runner emits -1; the scheduler replaces it with the request's index.
    request_id: int
    # Copied from the originating RequestSpec.
    profile_name: str
    # False when the run raised; `error` then holds repr() of the exception.
    ok: bool
    error: Optional[str]
    # Wall time of the whole run_one call, in milliseconds.
    latency_ms: float
    # The runner emits 0.0; the scheduler fills in the real value.
    queue_wait_ms: float
    # Wall time of each denoiser call, in milliseconds, in execution order.
    step_latencies_ms: List[float]
    ttfs_ms: float  # time to first (denoising) step; NaN if no step completed
    # torch.cuda.max_memory_allocated() at the end of the run, in MiB.
    peak_vram_mb: float
    # The model's own memory estimate in MiB, or None when unavailable.
    est_mem_mb: Optional[float]
    # time.perf_counter() readings taken at the start and end of the run.
    started_at: float
    ended_at: float
|
|
|
|
|
|
@dataclasses.dataclass
class RunSummary:
    """Aggregate statistics for one benchmark run."""

    # Request counts: everything that ran, then the ok / not-ok split.
    total_requests: int
    success: int
    failed: int
    # Successful requests per second of wall time.
    throughput_req_s: float
    # End-to-end latency distribution of successful requests (milliseconds).
    p50_ms: float
    p90_ms: float
    p95_ms: float
    p99_ms: float
    mean_ms: float
    # Time-to-first-step distribution of successful requests (milliseconds).
    ttfs_p50_ms: float
    ttfs_p99_ms: float
    # Per-step latency pooled over all successful requests (milliseconds).
    step_mean_ms: float
    step_p99_ms: float
    # Largest per-request peak VRAM reading, in MiB.
    max_vram_mb: float
|
|
|
|
|
|
# -----------------------------
|
|
# Helpers
|
|
# -----------------------------
|
|
|
|
def percentile(values: List[float], p: float) -> float:
    """Return the p-th percentile of ``values`` using linear interpolation.

    The input is sorted into a copy (the caller's list is not modified). The
    fractional rank ``(len - 1) * p / 100`` is interpolated between its two
    neighbouring order statistics.

    Args:
        values: Samples to summarise; may be empty.
        p: Percentile to compute, within [0, 100].

    Returns:
        The interpolated percentile, or NaN when ``values`` is empty.

    Raises:
        ValueError: If ``p`` is outside [0, 100] (including NaN).
    """
    # Out-of-range ranks would either index past the end (p > 100) or wrap
    # around through negative indexing (p < 0) and return a wrong sample.
    if not (0.0 <= p <= 100.0):
        raise ValueError(f"percentile p must be within [0, 100], got {p!r}")
    if not values:
        return float("nan")

    ordered = sorted(values)
    rank = (len(ordered) - 1) * (p / 100.0)
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # The rank lands exactly on a sample, so no interpolation is needed.
        return ordered[lower]
    return ordered[lower] * (upper - rank) + ordered[upper] * (rank - lower)
|
|
|
|
|
|
def now() -> float:
    """Return the current ``time.perf_counter()`` reading, in seconds."""
    stamp = time.perf_counter()
    return stamp
|
|
|
|
|
|
def gpu_peak_mb() -> float:
    """Return the peak CUDA memory allocated, in MiB (0.0 without CUDA)."""
    if torch.cuda.is_available():
        bytes_per_mib = 1024 ** 2
        return torch.cuda.max_memory_allocated() / bytes_per_mib
    return 0.0
|
|
|
|
|
|
def reset_gpu_peak() -> None:
    """Reset CUDA's peak-memory statistics; do nothing when CUDA is absent."""
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
|
|
|
|
|
|
def sync_cuda() -> None:
    """Block until queued CUDA work has finished; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
|
|
|
|
|
|
def build_request_stream(
    num_requests: int,
    base_seed: int,
    profiles: List[RequestSpec],
    weighted: Optional[List[float]] = None,
) -> List[RequestSpec]:
    """Draw a reproducible sequence of requests from ``profiles``.

    A private ``random.Random(base_seed)`` picks one profile per request
    (optionally weighted). Each pick is copied with its ``seed`` set to
    ``base_seed + index``, so the profile objects themselves are not mutated.

    Args:
        num_requests: Number of requests to generate.
        base_seed: Seeds the profile picker and offsets every request seed.
        profiles: Candidate request templates.
        weighted: Optional relative weights, parallel to ``profiles``.

    Returns:
        The generated requests, in draw order.
    """
    picker = random.Random(base_seed)
    return [
        dataclasses.replace(
            picker.choices(profiles, weights=weighted, k=1)[0],
            seed=base_seed + index,
        )
        for index in range(num_requests)
    ]
|
|
|
|
|
|
# -----------------------------
|
|
# Model adapter
|
|
# -----------------------------
|
|
|
|
class WanRunner:
    """
    Thin adapter around ComfyUI model loading + the BaseModel.apply_model call path.

    Only the denoiser's apply_model call is timed; this class performs no VAE
    encode/decode, text encoding or sampler logic.

    Latent shape convention (WAN):  [B, 16, T, H//8, W//8]
    Text conditioning shape (UMT5): [B, text_seq_len, text_dim]  (zeros for benchmarking)
    Sigma schedule (flow-matching): linspace(1.0 → 1/steps, steps)

    NOTE: the ``device`` argument is only recorded as ``device_str``; tensors
    are created on ``comfy.model_management.get_torch_device()``.
    """

    def __init__(
        self,
        checkpoint: str,
        device: str,
        dtype_str: str,
        text_seq_len: int = 512,
        text_dim: int = 4096,
    ):
        """Store the configuration and load the model immediately.

        Args:
            checkpoint: Path handed to ``comfy.sd.load_diffusion_model``.
            device: Recorded for reference only (see the class note).
            dtype_str: "fp16", "bf16" or "fp32"; any other value leaves the
                dtype choice to the loader.
            text_seq_len: Sequence length of the fake text conditioning.
            text_dim: Embedding width of the fake text conditioning.
        """
        self.checkpoint = checkpoint
        self.device_str = device
        self.dtype_str = dtype_str
        self.text_seq_len = text_seq_len
        self.text_dim = text_dim
        self.patcher, self.model = self._load_model()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _load_model(self):
        """Load the checkpoint and return ``(patcher, patcher.model)``."""
        dtype_map = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        dtype = dtype_map.get(self.dtype_str)
        # An unrecognised dtype string passes no dtype option to the loader.
        model_opts = {"dtype": dtype} if dtype is not None else {}

        patcher = comfy.sd.load_diffusion_model(self.checkpoint, model_options=model_opts)
        # force_full_load=True keeps the whole model resident on GPU rather than
        # streaming weights on demand (important for latency benchmarking).
        comfy.model_management.load_models_gpu([patcher], force_full_load=True)
        return patcher, patcher.model

    def _estimate_mem_mb(self, latent_shape: tuple, text_seq_len: int) -> Optional[float]:
        """Return the model's own memory estimate in MiB, or None on failure.

        Best-effort: any exception raised by ``memory_required`` is swallowed
        so that a missing or failing estimate never fails the request.
        """
        cond_shapes = {
            "c_crossattn": [(latent_shape[0], text_seq_len, self.text_dim)],
        }
        try:
            return self.model.memory_required(latent_shape, cond_shapes) / (1024 ** 2)
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Single-request execution
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def run_one(self, req: RequestSpec) -> RequestResult:
        """Run ``req.steps`` denoiser calls and return the measurements.

        Exceptions raised while preparing inputs or stepping are caught and
        reported through ``ok=False`` / ``error=repr(exc)`` instead of
        propagating. ``request_id`` (-1) and ``queue_wait_ms`` (0.0) are
        placeholders for the scheduler to fill in.
        """
        start = now()
        reset_gpu_peak()

        step_latencies: List[float] = []
        ttfs_ms = float("nan")
        est_mem_mb: Optional[float] = None
        ok = True
        err = None

        try:
            device = comfy.model_management.get_torch_device()
            dtype = self.model.get_dtype_inference()

            # Latent noise tensor: [B, 16 channels, T frames, H/8, W/8]
            latent_shape = (
                req.batch_size, 16,
                req.num_frames,
                req.height // 8,
                req.width // 8,
            )
            x = torch.randn(latent_shape, dtype=dtype, device=device)
            est_mem_mb = self._estimate_mem_mb(latent_shape, self.text_seq_len)

            # Fake text conditioning — zeros have the right shape, non-zero
            # values are not needed for throughput/latency benchmarking.
            cross_attn = torch.zeros(
                req.batch_size, self.text_seq_len, self.text_dim,
                dtype=dtype, device=device,
            )

            # Linear sigma schedule: 1.0 → 1/steps (flow-matching, noise→clean)
            sigmas = torch.linspace(1.0, 1.0 / req.steps, req.steps, device=device)

            # Drain the asynchronous setup work queued above so it is not
            # charged to the first timed step (and therefore to TTFS).
            sync_cuda()

            for step_i, sigma_val in enumerate(sigmas):
                sigma_t = sigma_val.expand(req.batch_size)
                t0 = now()
                x = self.model.apply_model(x, sigma_t, c_crossattn=cross_attn)
                # Wait for the GPU so the wall-clock delta covers the whole step.
                sync_cuda()
                elapsed_ms = (now() - t0) * 1000.0
                step_latencies.append(elapsed_ms)
                if step_i == 0:
                    ttfs_ms = elapsed_ms

        except Exception as e:
            ok = False
            err = repr(e)

        end = now()
        return RequestResult(
            request_id=-1,
            profile_name=req.profile_name,
            ok=ok,
            error=err,
            latency_ms=(end - start) * 1000.0,
            queue_wait_ms=0.0,  # filled in by the scheduler
            step_latencies_ms=step_latencies,
            ttfs_ms=ttfs_ms,
            peak_vram_mb=gpu_peak_mb(),
            est_mem_mb=est_mem_mb,
            started_at=start,
            ended_at=end,
        )
|
|
|
|
|
|
# -----------------------------
|
|
# Serving-style scheduler
|
|
# -----------------------------
|
|
|
|
async def run_closed_loop(
    runner: WanRunner,
    requests: List[RequestSpec],
    concurrency: int,
    request_rate: float = float("inf"),
) -> List[RequestResult]:
    """
    Closed-loop scheduler (default) or Poisson open-loop when request_rate is finite.

    Each request is dispatched to a thread so the asyncio event loop stays
    free to issue the next request while the GPU is busy.

    Args:
        runner: Object whose ``run_one(req)`` executes a single request.
        requests: Requests to execute; results come back in this order.
        concurrency: Maximum number of requests in flight at once (>= 1).
        request_rate: Mean arrival rate in requests per second. ``inf`` or any
            value <= 0 submits every request immediately. Inter-arrival gaps
            are drawn from the module-level ``random`` generator.

    Returns:
        One result per request, with ``request_id`` set to the request's index
        and ``queue_wait_ms`` set to the delay between the request's arrival
        and the start of its execution.

    Raises:
        ValueError: If ``concurrency`` is smaller than 1.
    """
    # A zero-width semaphore can never be acquired, so every worker would
    # wait forever; reject it up front instead of hanging.
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    sem = asyncio.Semaphore(concurrency)
    results: List[Optional[RequestResult]] = [None] * len(requests)

    async def worker(i: int, req: RequestSpec) -> None:
        # Stamp the arrival before waiting for a slot so the time spent
        # blocked on the semaphore is counted as queueing delay.
        t_enq = now()
        async with sem:
            res = await asyncio.to_thread(runner.run_one, req)
            res.request_id = i
            res.queue_wait_ms = max(0.0, (res.started_at - t_enq) * 1000.0)
            results[i] = res

    if request_rate == float("inf") or request_rate <= 0:
        await asyncio.gather(*(worker(i, r) for i, r in enumerate(requests)))
    else:
        tasks: List[asyncio.Task] = []
        for i, req in enumerate(requests):
            if i > 0:
                # Exponential inter-arrival gaps give a Poisson arrival process.
                await asyncio.sleep(random.expovariate(request_rate))
            tasks.append(asyncio.create_task(worker(i, req)))
        await asyncio.gather(*tasks)

    return [r for r in results if r is not None]
|
|
|
|
|
|
def summarize(results: List[RequestResult], wall_s: float) -> RunSummary:
    """Fold per-request results into a RunSummary.

    Latency, TTFS and step statistics consider successful requests only (TTFS
    additionally skips non-finite values). Peak VRAM is the maximum over every
    result, failed ones included. Statistics with no samples are NaN;
    throughput is 0.0 when ``wall_s`` is not positive.
    """
    succeeded = [r for r in results if r.ok]
    n_ok = len(succeeded)

    latencies = [r.latency_ms for r in succeeded]
    first_steps = [r.ttfs_ms for r in succeeded if math.isfinite(r.ttfs_ms)]
    step_times: List[float] = []
    for r in succeeded:
        step_times.extend(r.step_latencies_ms)

    nan = float("nan")
    throughput = n_ok / wall_s if wall_s > 0 else 0.0
    latency_mean = statistics.mean(latencies) if latencies else nan
    step_mean = statistics.mean(step_times) if step_times else nan
    vram_peak = max((r.peak_vram_mb for r in results), default=0.0)

    return RunSummary(
        total_requests=len(results),
        success=n_ok,
        failed=len(results) - n_ok,
        throughput_req_s=throughput,
        p50_ms=percentile(latencies, 50),
        p90_ms=percentile(latencies, 90),
        p95_ms=percentile(latencies, 95),
        p99_ms=percentile(latencies, 99),
        mean_ms=latency_mean,
        ttfs_p50_ms=percentile(first_steps, 50),
        ttfs_p99_ms=percentile(first_steps, 99),
        step_mean_ms=step_mean,
        step_p99_ms=percentile(step_times, 99),
        max_vram_mb=vram_peak,
    )
|
|
|
|
|
|
def print_summary(
    args: argparse.Namespace,
    summ: RunSummary,
    total_requests: int,
    wall_s: float,
) -> None:
    """Print a fixed-width, human-readable benchmark report to stdout.

    Args:
        args: Parsed CLI options; reads ``checkpoint``, ``device``, ``dtype``,
            ``concurrency`` and ``request_rate``.
        summ: Aggregated statistics to display.
        total_requests: Number of benchmark requests that were issued.
        wall_s: Wall-clock duration of the benchmark, in seconds.
    """
    w = 60
    sep = "-" * w
    banner = "=" * w

    def row(label: str, value: Any, spec: str = "<") -> None:
        # Label padded to 40 columns, then the value rendered with `spec`.
        print("{:<40} {}".format(label, format(value, spec)))

    # Use the scheduler's own rule: infinite or non-positive rates both run
    # closed-loop, so they must not be reported as e.g. "0.0 req/s".
    closed_loop = args.request_rate == float("inf") or args.request_rate <= 0
    rate_str = "inf (closed-loop)" if closed_loop else f"{args.request_rate:.1f} req/s"

    print("\n" + banner)
    print("{s:^{n}}".format(s=" WAN Benchmark Result ", n=w))
    print(banner)
    row("Checkpoint:", Path(args.checkpoint).name)
    row("Device / dtype:", f"{args.device}/{args.dtype}")
    row("Concurrency:", args.concurrency)
    row("Request rate:", rate_str)
    print(sep)
    row("Benchmark duration (s):", wall_s, "<.2f")
    row("Successful requests:", f"{summ.success}/{total_requests}")
    if summ.failed:
        row("Failed requests:", summ.failed)
    print(sep)
    row("Throughput (req/s):", summ.throughput_req_s, "<.3f")
    row("Latency mean (ms):", summ.mean_ms, "<.1f")
    row("Latency p50 (ms):", summ.p50_ms, "<.1f")
    row("Latency p90 (ms):", summ.p90_ms, "<.1f")
    row("Latency p95 (ms):", summ.p95_ms, "<.1f")
    row("Latency p99 (ms):", summ.p99_ms, "<.1f")
    print(sep)
    row("TTFS p50 (ms):", summ.ttfs_p50_ms, "<.1f")
    row("TTFS p99 (ms):", summ.ttfs_p99_ms, "<.1f")
    row("Step latency mean (ms):", summ.step_mean_ms, "<.1f")
    row("Step latency p99 (ms):", summ.step_p99_ms, "<.1f")
    print(sep)
    row("Peak VRAM (MB):", summ.max_vram_mb, "<.1f")
    print(banner)
|
|
|
|
|
|
# -----------------------------
|
|
# CLI
|
|
# -----------------------------
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse and validate the benchmark's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        The populated namespace.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including the
            range checks performed below.
    """
    p = argparse.ArgumentParser(
        description="Benchmark ComfyUI WAN diffusion model denoising throughput and latency."
    )
    p.add_argument(
        "--checkpoint", required=True,
        help="Path to the WAN diffusion-model checkpoint (.safetensors / .pt).",
    )
    p.add_argument("--device", default="cuda")
    p.add_argument("--dtype", default="fp16", choices=["fp16", "bf16", "fp32"])
    p.add_argument("--num-requests", type=int, default=100)
    p.add_argument("--concurrency", type=int, default=4,
                   help="Max number of in-flight requests (semaphore width).")
    p.add_argument(
        "--request-rate", type=float, default=float("inf"),
        help="Poisson arrival rate in req/s. inf = closed-loop (default).",
    )
    p.add_argument("--warmup-requests", type=int, default=2,
                   help="Warmup iterations excluded from metrics.")
    p.add_argument("--seed", type=int, default=1234)
    p.add_argument("--text-seq-len", type=int, default=512,
                   help="Cross-attention sequence length (UMT5 default: 512).")
    p.add_argument("--text-dim", type=int, default=4096,
                   help="Text embedding width (UMT5-XXL: 4096).")
    p.add_argument("--out-dir", type=Path, default=Path("benchmarks/out"))
    p.add_argument("--output-file", type=Path, default=None,
                   help="Override path for the summary JSON output.")
    args = p.parse_args(argv)

    # A zero-width semaphore would block every request forever.
    if args.concurrency < 1:
        p.error("--concurrency must be >= 1")
    # Negative counts would turn the warmup/benchmark split into negative
    # slice indices and silently run the wrong requests.
    if args.num_requests < 0:
        p.error("--num-requests must be >= 0")
    if args.warmup_requests < 0:
        p.error("--warmup-requests must be >= 0")
    return args
|
|
|
|
|
|
def default_profiles() -> List[RequestSpec]:
    """Return the built-in workload profiles.

    Every profile uses batch size 1, cfg_scale 6.0 and a placeholder seed of
    0; they differ only in resolution, frame count and step count.
    """
    # (profile_name, width, height, num_frames, steps)
    table = (
        ("wan21_t2v_720p_16f_30s", 1280, 720, 16, 30),
        ("wan21_t2v_720p_32f_30s", 1280, 720, 32, 30),
        ("wan21_t2v_480p_32f_20s", 854, 480, 32, 20),
    )
    return [
        RequestSpec(
            profile_name=name,
            batch_size=1,
            width=width,
            height=height,
            num_frames=frames,
            steps=steps,
            cfg_scale=6.0,
            seed=0,
        )
        for name, width, height, frames, steps in table
    ]
|
|
|
|
|
|
async def main_async() -> None:
    """Run the benchmark end to end.

    Parses the CLI options, loads the model, runs the warmup requests, runs
    the measured requests through the scheduler, prints the report, and
    writes ``requests.jsonl`` (one JSON object per request) into the output
    directory plus the summary JSON (``summary.json`` there unless
    ``--output-file`` overrides it).
    """
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    out_file = args.output_file or (args.out_dir / "summary.json")
    # Create the summary's directory before any GPU work so that a bad
    # --output-file path fails here rather than after the whole benchmark.
    out_file.parent.mkdir(parents=True, exist_ok=True)

    runner = WanRunner(
        checkpoint=args.checkpoint,
        device=args.device,
        dtype_str=args.dtype,
        text_seq_len=args.text_seq_len,
        text_dim=args.text_dim,
    )

    all_reqs = build_request_stream(
        args.num_requests + args.warmup_requests,
        args.seed,
        default_profiles(),
    )
    warmup_reqs = all_reqs[: args.warmup_requests]
    bench_reqs = all_reqs[args.warmup_requests :]

    if warmup_reqs:
        print(f"Running {len(warmup_reqs)} warmup request(s)...")
        for req in warmup_reqs:
            warm = runner.run_one(req)
            # run_one reports failures through the result instead of raising;
            # surface them so a broken setup is visible before measuring.
            if not warm.ok:
                print(f"Warning: warmup request failed: {warm.error}", file=sys.stderr)
        print("Warmup complete.")

    print(f"Benchmarking {len(bench_reqs)} requests (concurrency={args.concurrency})...")
    t0 = now()
    results = await run_closed_loop(runner, bench_reqs, args.concurrency, args.request_rate)
    wall_s = now() - t0

    summ = summarize(results, wall_s)
    print_summary(args, summ, len(bench_reqs), wall_s)

    with open(args.out_dir / "requests.jsonl", "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(dataclasses.asdict(r)) + "\n")
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(dataclasses.asdict(summ), f, indent=2)
    print(f"\nResults written to {args.out_dir}/")
    if args.output_file is not None:
        # The summary may live outside out_dir when overridden; say where.
        print(f"Summary written to {out_file}")


if __name__ == "__main__":
    asyncio.run(main_async())
|