server-bench : add speed-bench for speculative decoding benchmarking (#23869)

* spec: add speed-bench support for benchmarking

* speed-bench : add trailing newline to requirements.txt

* speed-bench : bump datasets to 4.8.0 to fix ty check

* server-bench : remove now-unused type: ignore after datasets bump
Ruixiang Wang
2026-05-29 23:09:47 +02:00
committed by GitHub
parent 5a46b46acd
commit 689a9a470e
7 changed files with 643 additions and 2 deletions
@@ -323,3 +323,8 @@ statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
- `#acc tokens`: number of tokens accepted by the main model
- `dur(b,g,a)`: durations of begin (new prompt), generation and accumulation (process acceptance).
## Benchmarking
To measure the end-to-end effect of speculative decoding (throughput, latency, and draft acceptance) across diverse prompts, see the SPEED-Bench client in [tools/server/bench/speed-bench](../tools/server/bench/speed-bench/README.md).
It runs against a running `llama-server` and can compare a baseline run against a speculative-decoding run.
@@ -1,4 +1,4 @@
-datasets~=3.2.0
+datasets~=4.8.0
 matplotlib~=3.10.0
 numpy~=1.26.4
 requests~=2.32.3
@@ -25,7 +25,7 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
     ret = []
     if dataset_name.lower() == "mmlu":
         logger.info("Loading MMLU dataset...")
-        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]
     else:
         return None
     if n_prompts >= 0:
@@ -0,0 +1,117 @@
# SPEED-Bench server benchmark
A lightweight [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) client for benchmarking an already-running `llama-server` through its OpenAI-compatible API. It is primarily meant to evaluate speculative decoding (draft model, n-gram, MTP, EAGLE3, ...) by reporting per-category throughput, latency, and draft acceptance.
The dataset handling follows the [aiperf SPEED-Bench tutorial](https://github.com/ai-dynamo/aiperf/blob/main/docs/tutorials/speed-bench.md), which also documents the dataset layout in more detail.
## Install
```bash
pip install -r tools/server/bench/speed-bench/requirements.txt
```
## Start a server
The client does not launch the server, so start `llama-server` yourself first. If you care about throughput numbers, set the client `--concurrency` to the server's slot count (`-np` / `--parallel`):
```bash
llama-server \
-m target.gguf \
-c 8192 \
--port 8080 \
-ngl 99 -fa on \
-np 1 \
--jinja
```
For speculative decoding, start the server with the appropriate flags for your setup (e.g. a draft model with `-md`, or `--spec-type ngram-mod`). See the [speculative decoding doc](../../../../docs/speculative.md) for details.
## Run
```bash
python tools/server/bench/speed-bench/speed_bench.py \
--url localhost:8080 \
--bench qualitative \
--category coding \
--osl 1024 \
--concurrency 1
```
## Options
| Option | Default | Description |
| --- | --- | --- |
| `--url` | `localhost:8080` | Server URL. The scheme and `/v1` are optional and a trailing slash is fine, so `localhost:8080` and `http://localhost:8080/v1/` both work. |
| `--model` | none | Optional `model` field sent in each request. |
| `--bench` | `qualitative` | SPEED-Bench config, e.g. `qualitative`, `throughput_1k`. See [available dataset variants](https://github.com/ai-dynamo/aiperf/blob/main/docs/tutorials/speed-bench.md#available-dataset-variants). |
| `--category` | `all` | Category filter within the bench; comma-separated list or `all`. For `qualitative` the categories are `coding`, `humanities`, `math`, `multilingual`, `qa`, `rag`, `reasoning`, `roleplay`, `stem`, `summarization`, `writing`. For the `throughput_{ISL}` splits they are `high_entropy`, `low_entropy`, `mixed`. |
| `--osl` | `4096` | Output sequence length, mapped to `max_tokens`. Always overrides any `max_tokens` passed via `--extra-inputs`. |
| `--extra-inputs` | `{"temperature":0}` | Extra request fields as a JSON object. |
| `--concurrency` | `1` | Concurrent client requests; usually match the server's `-np`. |
| `--limit` | none | Max samples per category (handy for smoke tests). |
| `--timeout` | `600` | Per-request timeout in seconds. |
| `--output` | none | Save raw per-request results and the summary to JSON. |
A few common ones:
- `--category all` runs every category in the bench.
- `--category coding,math` runs just those two.
- `--bench throughput_8k` runs a fixed-input-length throughput split.
- `--limit 8` keeps at most 8 samples per category, which is enough for a quick check.
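The `--url` handling described in the options table can be sketched as follows. This is a minimal standalone version that mirrors `normalize_base_url` in `speed_bench.py`; the two sample inputs are the ones named in the table.

```python
from urllib.parse import urlparse


def normalize_base_url(url: str) -> str:
    """Return a base URL that has a scheme and ends in /v1, without a trailing slash."""
    url = url.strip().rstrip("/")
    if not url:
        raise ValueError("url cannot be empty")
    if "://" not in url:
        url = "http://" + url  # plain HTTP is assumed when no scheme is given
    parsed = urlparse(url)
    if not parsed.scheme or not parsed.netloc:
        raise ValueError(f"invalid url: {url}")
    if not parsed.path.rstrip("/").endswith("/v1"):
        url += "/v1"
    return url


for candidate in ("localhost:8080", "http://localhost:8080/v1/"):
    print(candidate, "->", normalize_base_url(candidate))
```

Both inputs resolve to the same base URL, to which the client then appends `/chat/completions`.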
The `throughput_{ISL}` splits use fixed input lengths (1k - 32k), so they are handy for long-context testing and for comparing different `llama-server` batching settings (e.g. sweeping `-ub` / `--ubatch-size`) on prompts of a known size. Make sure the server `-c` is large enough for the chosen split. When raising `-ub`, also raise `-b` to at least the same value, since the physical ubatch cannot exceed the logical batch.
When `--output` is given, the JSON file holds the run `config`, the `selected_samples` / `completed_samples` / `failed_samples` counts, the per-category `summary` rows, and the per-sample `results`.
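As a rough illustration of that layout, the sketch below writes a small hand-made file with the same top-level keys and reads it back. The values, the sample ids, and the `inspect_run` helper are hypothetical; only the key names come from `speed_bench.py`.

```python
import json
import tempfile
from pathlib import Path

# Hand-written stand-in for a file produced by --output; all values are made up.
example = {
    "config": {"bench": "qualitative", "category": "coding,math", "osl": 1024},
    "selected_samples": 3,
    "completed_samples": 2,
    "failed_samples": 1,
    "summary": [
        {"category": "coding", "requests": 1, "avg_pred_t_s": 52.0, "accept_rate": 0.71},
        {"category": "math", "requests": 1, "avg_pred_t_s": 48.5, "accept_rate": None},
    ],
    "results": [
        {"id": "q1", "category": "coding", "ok": True, "error": None},
        {"id": "q2", "category": "math", "ok": True, "error": None},
        {"id": "q3", "category": "math", "ok": False, "error": "HTTP 500: example"},
    ],
}
path = Path(tempfile.mkdtemp()) / "run.json"
path.write_text(json.dumps(example), encoding="utf-8")


def inspect_run(run_path):
    """Return (decode speed per category, ids of failed samples) for one run file."""
    data = json.loads(Path(run_path).read_text(encoding="utf-8"))
    speeds = {row["category"]: row.get("avg_pred_t_s") for row in data["summary"]}
    failed = [result["id"] for result in data["results"] if not result["ok"]]
    return speeds, failed


speeds, failed = inspect_run(path)
print(speeds)
print(failed)
```

Checking the failed ids before comparing two runs is a cheap way to confirm both runs completed the same samples.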
## Metrics
The summary prints one row per category plus an `overall` row:
- `samples` - how many samples finished successfully.
- `avg_prompt_t/s` - prefill throughput from llama.cpp (`timings.prompt_per_second`), averaged over the category's samples.
- `avg_pred_t/s` - decode throughput from llama.cpp (`timings.predicted_per_second`), averaged over the category's samples.
- `avg_latency` - average end-to-end request latency seen by the client.
- `accept_rate` - `accepted / draft_n` over the category, or `n/a` if nothing was drafted (`draft_n == 0`).
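Note that `accept_rate` is pooled over the category (total accepted over total drafted) rather than averaged per sample, so samples that draft more tokens weigh more. A minimal sketch with made-up counts; the field names follow the per-sample `results` entries, and `mean_of_rates` is shown only for contrast:

```python
def pooled_accept_rate(samples):
    """Total accepted draft tokens over total drafted tokens, or None if nothing was drafted."""
    draft_n = sum(sample["draft_n"] for sample in samples)
    accepted = sum(sample["draft_n_accepted"] for sample in samples)
    return accepted / draft_n if draft_n > 0 else None


def mean_of_rates(samples):
    """Unweighted mean of per-sample rates; NOT what the summary reports."""
    rates = [s["draft_n_accepted"] / s["draft_n"] for s in samples if s["draft_n"] > 0]
    return sum(rates) / len(rates) if rates else None


# Hypothetical per-sample counts for one category.
samples = [
    {"draft_n": 100, "draft_n_accepted": 80},
    {"draft_n": 10, "draft_n_accepted": 2},
]

print(f"pooled: {pooled_accept_rate(samples):.4f}")
print(f"mean of rates: {mean_of_rates(samples):.4f}")
print(pooled_accept_rate([{"draft_n": 0, "draft_n_accepted": 0}]))
```

With these counts the pooled rate is 82 / 110, while the unweighted mean of the two per-sample rates is 0.5.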
## Baseline vs speculative decoding
Save a run from each server with `--output`, then diff the two JSON files with `speed_bench_compare.py`.
First, start a plain `llama-server` (no speculative decoding) and save a baseline:
```bash
python tools/server/bench/speed-bench/speed_bench.py \
--url localhost:8080 \
--bench qualitative \
--category all \
--osl 1024 \
--concurrency 1 \
--output baseline.json
```
Then restart `llama-server` with speculative decoding enabled and save another run:
```bash
python tools/server/bench/speed-bench/speed_bench.py \
--url localhost:8080 \
--bench qualitative \
--category all \
--osl 1024 \
--concurrency 1 \
--output spec.json
```
Finally compare the two:
```bash
python tools/server/bench/speed-bench/speed_bench_compare.py \
--baseline baseline.json \
--speculative spec.json
```
The comparison table adds:
- `decode_speedup = spec_avg_pred_t/s / base_avg_pred_t/s`
- `latency_speedup = base_avg_latency / spec_avg_latency`
Keep `--bench`, `--category`, `--osl`, and `--limit` the same across both runs, otherwise they won't be using the same prompts.
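As a worked example of the two ratios, the sketch below applies them to one pair of summary rows. The numbers are made up and `ratio` is a hypothetical helper; only the key names `avg_pred_t_s` and `avg_latency` come from the summary rows.

```python
def ratio(numerator, denominator):
    """Return numerator / denominator, or None when either value is missing or zero."""
    if not numerator or not denominator:
        return None
    return numerator / denominator


# Hypothetical summary rows for the same category from the two runs.
base = {"category": "coding", "avg_pred_t_s": 40.0, "avg_latency": 12.0}
spec = {"category": "coding", "avg_pred_t_s": 60.0, "avg_latency": 8.0}

# Throughput: higher is better, so the speculative run goes in the numerator.
decode_speedup = ratio(spec["avg_pred_t_s"], base["avg_pred_t_s"])
# Latency: lower is better, so the baseline goes in the numerator.
latency_speedup = ratio(base["avg_latency"], spec["avg_latency"])

print(f"decode_speedup={decode_speedup:.2f}x latency_speedup={latency_speedup:.2f}x")
```

In both columns a value above `1.00x` means the speculative run was faster, and a missing value on either side yields `n/a` in the table.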
@@ -0,0 +1,3 @@
datasets
requests
tqdm
@@ -0,0 +1,432 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import concurrent.futures
import json
import statistics
import sys
import time
from dataclasses import asdict, dataclass
from typing import Any
from urllib.parse import urlparse
import requests
from datasets import get_dataset_config_names, load_dataset
from tqdm import tqdm
DATASET_REPO = "nvidia/SPEED-Bench"
@dataclass
class Sample:
    id: str
    category: str
    turns: list[str]


@dataclass
class RequestResult:
    id: str
    category: str
    ok: bool
    turns: int
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    finish_reason: str | None
    draft_n: int
    draft_n_accepted: int
    prompt_ms: float | None
    predicted_ms: float | None
    prompt_per_second: float | None
    predicted_per_second: float | None
    error: str | None


def normalize_base_url(url: str) -> str:
    url = url.strip().rstrip("/")
    if not url:
        raise ValueError("--url cannot be empty")
    if "://" not in url:
        url = "http://" + url
    parsed = urlparse(url)
    if not parsed.scheme or not parsed.netloc:
        raise ValueError(f"invalid --url: {url}")
    if not parsed.path.rstrip("/").endswith("/v1"):
        url = url + "/v1"
    return url.rstrip("/")


def parse_extra_inputs(value: str) -> dict[str, Any]:
    extra = json.loads(value)
    if not isinstance(extra, dict):
        raise ValueError("--extra-inputs must be a JSON object")
    return extra


def extract_turns(row: dict[str, Any]) -> list[str]:
    turns = row.get("turns")
    if isinstance(turns, list) and turns:
        clean_turns = [str(turn).strip() for turn in turns if turn and str(turn).strip()]
        if clean_turns:
            return clean_turns
    raise ValueError("missing or empty turns")
def load_samples(args: argparse.Namespace) -> list[Sample]:
    bench_names = get_dataset_config_names(DATASET_REPO)
    if args.bench not in bench_names:
        raise ValueError(
            f"unknown --bench {args.bench!r}; available benches: {', '.join(bench_names)}"
        )

    dataset = load_dataset(DATASET_REPO, name=args.bench, split="test")
    categories = list(dict.fromkeys(str(category) for category in dataset["category"]))
    requested_categories = None
    if args.category != "all":
        requested_list = [category.strip() for category in args.category.split(",") if category.strip()]
        if not requested_list:
            raise ValueError(
                f"--category must be 'all' or a comma-separated list; available categories: {', '.join(categories)}"
            )
        requested_categories = set(requested_list)
        unknown_categories = [category for category in requested_list if category not in categories]
        if unknown_categories:
            unknown = ", ".join(unknown_categories)
            raise ValueError(
                f"unknown --category {unknown!r} for bench {args.bench!r}; "
                f"available categories: all, {', '.join(categories)}"
            )

    samples: list[Sample] = []
    samples_per_category: dict[str, int] = {}
    skipped = 0
    for index, row_raw in enumerate(dataset):
        row = dict(row_raw)
        category_raw = row.get("category")
        if not isinstance(category_raw, str) or not category_raw.strip():
            skipped += 1
            continue
        category = category_raw.strip()
        if requested_categories is not None and category not in requested_categories:
            continue
        if args.limit is not None and samples_per_category.get(category, 0) >= args.limit:
            continue
        try:
            turns = extract_turns(row)
        except ValueError:
            skipped += 1
            continue
        question_id = row.get("question_id")
        if not isinstance(question_id, str) or not question_id.strip():
            skipped += 1
            continue
        sample_id = question_id.strip()
        samples.append(Sample(id=sample_id, category=category, turns=turns))
        samples_per_category[category] = samples_per_category.get(category, 0) + 1

    if not samples:
        raise RuntimeError(f"no samples selected from bench={args.bench} category={args.category}")
    if skipped:
        # Rows are skipped when the category, the turns, or the question_id is unusable.
        print(f"speed_bench: skipped {skipped} rows without a usable category, turns, or question_id")
    return samples
def parse_completion_response(data: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], str | None, str]:
    usage = data.get("usage") or {}
    timings = data.get("timings") or {}
    finish_reason = None
    content = ""
    choices = data.get("choices")
    if isinstance(choices, list) and choices and isinstance(choices[0], dict):
        choice = choices[0]
        finish_reason = choice.get("finish_reason")
        message = choice.get("message")
        if isinstance(message, dict) and isinstance(message.get("content"), str):
            content = message["content"]
        elif isinstance(choice.get("text"), str):
            content = choice["text"]
    return usage, timings, finish_reason, content


def run_request(
    endpoint: str,
    model: str | None,
    messages: list[dict[str, str]],
    osl: int,
    extra_inputs: dict[str, Any],
    timeout: float,
) -> tuple[dict[str, Any], float]:
    payload: dict[str, Any] = {
        "messages": messages,
        "max_tokens": osl,
        "stream": False,
    }
    if model:
        payload["model"] = model
    payload.update(extra_inputs)
    payload["max_tokens"] = osl

    start = time.perf_counter()
    response = requests.post(endpoint, json=payload, timeout=timeout)
    latency_s = time.perf_counter() - start
    if response.status_code != 200:
        body = response.text[:500].replace("\n", "\\n")
        raise RuntimeError(f"HTTP {response.status_code}: {body}")
    return response.json(), latency_s
def run_one(
    sample: Sample,
    endpoint: str,
    model: str | None,
    osl: int,
    extra_inputs: dict[str, Any],
    timeout: float,
) -> RequestResult:
    selected_turns = sample.turns
    messages: list[dict[str, str]] = []
    total_latency_s = 0.0
    prompt_tokens = 0
    completion_tokens = 0
    total_tokens = 0
    draft_n = 0
    draft_n_accepted = 0
    prompt_ms = 0.0
    predicted_ms = 0.0
    prompt_per_second = None
    predicted_per_second = None
    finish_reason: str | None = None

    try:
        for turn in selected_turns:
            messages.append({"role": "user", "content": turn})
            data, latency_s = run_request(endpoint, model, messages, osl, extra_inputs, timeout)
            total_latency_s += latency_s
            usage, timings, finish_reason, assistant_text = parse_completion_response(data)
            turn_prompt_tokens = int(usage.get("prompt_tokens") or timings.get("prompt_n") or 0)
            turn_completion_tokens_count = int(usage.get("completion_tokens") or timings.get("predicted_n") or 0)
            turn_total_tokens_count = int(usage.get("total_tokens") or (turn_prompt_tokens + turn_completion_tokens_count))
            prompt_tokens += turn_prompt_tokens
            completion_tokens += turn_completion_tokens_count
            total_tokens += turn_total_tokens_count
            draft_n += int(timings.get("draft_n") or 0)
            draft_n_accepted += int(timings.get("draft_n_accepted") or 0)
            prompt_ms += float(timings.get("prompt_ms") or 0)
            predicted_ms += float(timings.get("predicted_ms") or 0)
            if len(selected_turns) == 1 and isinstance(timings.get("prompt_per_second"), (int, float)):
                prompt_per_second = float(timings["prompt_per_second"])
            if len(selected_turns) == 1 and isinstance(timings.get("predicted_per_second"), (int, float)):
                predicted_per_second = float(timings["predicted_per_second"])
            messages.append({"role": "assistant", "content": assistant_text})

        if total_tokens == 0:
            total_tokens = prompt_tokens + completion_tokens
        if len(selected_turns) > 1:
            prompt_per_second = (prompt_tokens / (prompt_ms / 1000)) if prompt_ms > 0 else None
            predicted_per_second = (completion_tokens / (predicted_ms / 1000)) if predicted_ms > 0 else None

        return RequestResult(
            id=sample.id,
            category=sample.category,
            ok=True,
            turns=len(selected_turns),
            latency_s=total_latency_s,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            finish_reason=finish_reason,
            draft_n=draft_n,
            draft_n_accepted=draft_n_accepted,
            prompt_ms=prompt_ms if prompt_ms > 0 else None,
            predicted_ms=predicted_ms if predicted_ms > 0 else None,
            prompt_per_second=prompt_per_second,
            predicted_per_second=predicted_per_second,
            error=None,
        )
    except Exception as exc:
        return RequestResult(
            id=sample.id,
            category=sample.category,
            ok=False,
            turns=len(selected_turns),
            latency_s=total_latency_s,
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            finish_reason=None,
            draft_n=0,
            draft_n_accepted=0,
            prompt_ms=None,
            predicted_ms=None,
            prompt_per_second=None,
            predicted_per_second=None,
            error=str(exc),
        )
def summarize_group(category: str, results: list[RequestResult]) -> dict[str, Any]:
    ok_results = [result for result in results if result.ok]
    latencies = [result.latency_s for result in ok_results]
    server_prompt_speeds = [
        result.prompt_per_second
        for result in ok_results
        if result.prompt_per_second is not None
    ]
    server_completion_speeds = [
        result.predicted_per_second
        for result in ok_results
        if result.predicted_per_second is not None
    ]
    turns = sum(result.turns for result in ok_results)
    draft_n = sum(result.draft_n for result in ok_results)
    accepted = sum(result.draft_n_accepted for result in ok_results)
    return {
        "category": category,
        "requests": len(ok_results),
        "turns": turns,
        "failed": len(results) - len(ok_results),
        "avg_prompt_t_s": statistics.mean(server_prompt_speeds) if server_prompt_speeds else None,
        "avg_pred_t_s": statistics.mean(server_completion_speeds) if server_completion_speeds else None,
        "avg_latency": statistics.mean(latencies) if latencies else None,
        "draft_n": draft_n,
        "accepted": accepted,
        "accept_rate": (accepted / draft_n) if draft_n > 0 else None,
    }


def fmt_value(value: Any, kind: str = "") -> str:
    if value is None:
        return "n/a"
    if kind == "int":
        return str(int(value))
    if kind == "rate":
        return f"{float(value):.4f}"
    if kind == "seconds":
        return f"{float(value):.3f}s"
    if kind == "speed":
        return f"{float(value):.2f}"
    if kind == "speedup":
        return f"{float(value):.2f}x"
    return str(value)


def print_table(rows: list[dict[str, Any]]) -> None:
    columns = [
        ("category", "category", ""),
        ("samples", "requests", "int"),
        ("avg_prompt_t/s", "avg_prompt_t_s", "speed"),
        ("avg_pred_t/s", "avg_pred_t_s", "speed"),
        ("avg_latency", "avg_latency", "seconds"),
        ("accept_rate", "accept_rate", "rate"),
    ]
    print_rows(rows, columns)


def print_rows(rows: list[dict[str, Any]], columns: list[tuple[str, str, str]]) -> None:
    rendered_rows = []
    for row in rows:
        rendered_rows.append([fmt_value(row.get(key), kind) for _, key, kind in columns])
    widths = [len(header) for header, _, _ in columns]
    for rendered in rendered_rows:
        for i, cell in enumerate(rendered):
            widths[i] = max(widths[i], len(cell))
    header = " ".join(header.ljust(widths[i]) for i, (header, _, _) in enumerate(columns))
    print(header)
    print(" ".join("-" * width for width in widths))
    for rendered in rendered_rows:
        print(" ".join(cell.ljust(widths[i]) for i, cell in enumerate(rendered)))
def save_output(path: str, args: argparse.Namespace, samples: list[Sample], results: list[RequestResult], summary: list[dict[str, Any]]) -> None:
    payload = {
        "config": {
            "url": args.url,
            "model": args.model,
            "bench": args.bench,
            "category": args.category,
            "osl": args.osl,
            "concurrency": args.concurrency,
            "extra_inputs": args.extra_inputs,
        },
        "selected_samples": len(samples),
        "completed_samples": sum(1 for result in results if result.ok),
        "failed_samples": sum(1 for result in results if not result.ok),
        "summary": summary,
        "results": [asdict(result) for result in results],
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, sort_keys=True)


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(description="Run SPEED-Bench against an OpenAI-compatible llama-server.")
    parser.add_argument("--url", default="localhost:8080", help="Server URL, for example localhost:8080 or http://localhost:8080/v1")
    parser.add_argument("--model", default=None, help="Optional model name to send in OpenAI requests")
    parser.add_argument("--bench", default="qualitative", help="SPEED-Bench config to run, for example qualitative or throughput_1k")
    parser.add_argument("--category", default="all", help="Category to run within the selected bench; use all for no category filter")
    parser.add_argument("--osl", type=int, default=4096, help="Output sequence length, mapped to max_tokens")
    parser.add_argument("--extra-inputs", default='{"temperature":0}', help="Extra request fields as a JSON object")
    parser.add_argument("--concurrency", type=int, default=1, help="Concurrent client requests; usually match llama-server -np")
    parser.add_argument("--limit", type=int, default=None, help="Optional sample limit per category for smoke tests")
    parser.add_argument("--timeout", type=float, default=600, help="Per-request timeout in seconds")
    parser.add_argument("--output", default=None, help="Optional path to save raw results JSON")
    args = parser.parse_args(argv)

    try:
        base_url = normalize_base_url(args.url)
        endpoint = base_url + "/chat/completions"
        extra_inputs = parse_extra_inputs(args.extra_inputs)
        args.extra_inputs = extra_inputs
        samples = load_samples(args)
    except Exception as exc:
        print(f"speed_bench: setup failed: {exc}", file=sys.stderr)
        return 2

    print(f"speed_bench: loaded {len(samples)} samples from bench={args.bench} category={args.category}")
    results: list[RequestResult] = []
    started = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.concurrency) as executor:
        futures = [
            executor.submit(run_one, sample, endpoint, args.model, args.osl, extra_inputs, args.timeout)
            for sample in samples
        ]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="speed_bench", unit="sample"):
            result = future.result()
            results.append(result)
    elapsed = time.perf_counter() - started

    categories = list(dict.fromkeys(sample.category for sample in samples))
    summary = [
        summarize_group(category, [result for result in results if result.category == category])
        for category in categories
    ]
    summary.append(summarize_group("overall", results))
    print()
    print(f"Summary (elapsed={elapsed:.2f}s)")
    print_table(summary)

    if args.output:
        save_output(args.output, args, samples, results, summary)
        print(f"\nspeed_bench: wrote {args.output}")

    failed = sum(1 for result in results if not result.ok)
    if failed:
        print(f"\nspeed_bench: {failed} samples failed", file=sys.stderr)
        first_error = next((result.error for result in results if result.error), None)
        if first_error:
            print(f"first error: {first_error}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from typing import Any
from speed_bench import fmt_value, print_rows
def load_summary(path: str) -> list[dict[str, Any]]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    summary = data.get("summary")
    if not isinstance(summary, list):
        raise ValueError(f"{path} does not contain a summary list")
    return summary


def compare_rows(baseline: list[dict[str, Any]], speculative: list[dict[str, Any]]) -> list[dict[str, Any]]:
    baseline_by_category = {row["category"]: row for row in baseline}
    comparisons = []
    for row in speculative:
        base = baseline_by_category.get(row["category"])
        if not base:
            continue
        base_speed = base.get("avg_pred_t_s")
        spec_speed = row.get("avg_pred_t_s")
        base_latency = base.get("avg_latency")
        spec_latency = row.get("avg_latency")
        comparisons.append(
            {
                "category": row["category"],
                "base_avg_pred_t_s": base_speed,
                "spec_avg_pred_t_s": spec_speed,
                "decode_speedup": (spec_speed / base_speed) if base_speed and spec_speed else None,
                "base_avg_latency": base_latency,
                "spec_avg_latency": spec_latency,
                "latency_speedup": (base_latency / spec_latency) if base_latency and spec_latency else None,
                "accept_rate": row.get("accept_rate"),
            }
        )
    return comparisons


def print_comparison(rows: list[dict[str, Any]]) -> None:
    if not rows:
        print("No overlapping categories found for comparison.")
        return
    columns = [
        ("category", "category", ""),
        ("base_avg_pred_t/s", "base_avg_pred_t_s", "speed"),
        ("spec_avg_pred_t/s", "spec_avg_pred_t_s", "speed"),
        ("decode_speedup", "decode_speedup", "speedup"),
        ("base_avg_latency", "base_avg_latency", "seconds"),
        ("spec_avg_latency", "spec_avg_latency", "seconds"),
        ("latency_speedup", "latency_speedup", "speedup"),
        ("accept_rate", "accept_rate", "rate"),
    ]
    print_rows(rows, columns)


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(description="Compare two SPEED-Bench runs (baseline vs speculative).")
    parser.add_argument("--baseline", required=True, help="Baseline results JSON produced by speed_bench.py --output")
    parser.add_argument("--speculative", required=True, help="Speculative decoding results JSON produced by speed_bench.py --output")
    args = parser.parse_args(argv)

    try:
        baseline = load_summary(args.baseline)
        speculative = load_summary(args.speculative)
    except Exception as exc:
        print(f"speed_bench_compare: failed to load inputs: {exc}", file=sys.stderr)
        return 2

    comparisons = compare_rows(baseline, speculative)
    print(f"Comparison: baseline={args.baseline} speculative={args.speculative}")
    print_comparison(comparisons)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())