mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
server-bench : add speed-bench for speculative decoding benchmarking (#23869)
* spec: add speed-bench support for benchmarking * speed-bench : add trailing newline to requirements.txt * speed-bench : bump datasets to 4.8.0 to fix ty check * server-bench : remove now-unused type: ignore after datasets bump
This commit is contained in:
@@ -323,3 +323,8 @@ statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
- `dur(b,g,a)`: durations of begin (new prompt), generation and accumulation (process acceptance).
|
||||
|
||||
## Benchmarking
|
||||
|
||||
To measure the end-to-end effect of speculative decoding (throughput, latency, and draft acceptance) across diverse prompts, see the SPEED-Bench client in [tools/server/bench/speed-bench](../tools/server/bench/speed-bench/README.md).
|
||||
It runs against a running `llama-server` and can compare a baseline run against a speculative-decoding run.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
datasets~=3.2.0
|
||||
datasets~=4.8.0
|
||||
matplotlib~=3.10.0
|
||||
numpy~=1.26.4
|
||||
requests~=2.32.3
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
|
||||
ret = []
|
||||
if dataset_name.lower() == "mmlu":
|
||||
logger.info("Loading MMLU dataset...")
|
||||
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore
|
||||
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]
|
||||
else:
|
||||
return None
|
||||
if n_prompts >= 0:
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
# SPEED-Bench server benchmark
|
||||
|
||||
A lightweight [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) client for benchmarking an already-running `llama-server` through its OpenAI-compatible API. It is primarily meant to evaluate speculative decoding (draft model, n-gram, MTP, EAGLE3, ...) by reporting per-category throughput, latency, and draft acceptance.
|
||||
|
||||
The dataset handling follows the [aiperf SPEED-Bench tutorial](https://github.com/ai-dynamo/aiperf/blob/main/docs/tutorials/speed-bench.md), which also documents the dataset layout in more detail.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pip install -r tools/server/bench/speed-bench/requirements.txt
|
||||
```
|
||||
|
||||
## Start a server
|
||||
|
||||
The client does not launch the server, so start `llama-server` yourself first. If you care about throughput numbers, set the client `--concurrency` to the server's slot count (`--np`):
|
||||
|
||||
```bash
|
||||
llama-server \
|
||||
-m target.gguf \
|
||||
-c 8192 \
|
||||
--port 8080 \
|
||||
-ngl 99 -fa on \
|
||||
--np 1 \
|
||||
--jinja
|
||||
```
|
||||
|
||||
For speculative decoding, start the server with the appropriate flags for your setup (e.g. a draft model with `-md`, or `--spec-type ngram-mod`). See the [speculative decoding doc](../../../../docs/speculative.md) for details.
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
python tools/server/bench/speed-bench/speed_bench.py \
|
||||
--url localhost:8080 \
|
||||
--bench qualitative \
|
||||
--category coding \
|
||||
--osl 1024 \
|
||||
--concurrency 1
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
| Option | Default | Description |
|
||||
| --- | --- | --- |
|
||||
| `--url` | `localhost:8080` | Server URL. The scheme and `/v1` are optional and a trailing slash is fine, so `localhost:8080` and `http://localhost:8080/v1/` both work. |
|
||||
| `--model` | none | Optional `model` field sent in each request. |
|
||||
| `--bench` | `qualitative` | SPEED-Bench config, e.g. `qualitative`, `throughput_1k`. See [available dataset variants](https://github.com/ai-dynamo/aiperf/blob/main/docs/tutorials/speed-bench.md#available-dataset-variants). |
|
||||
| `--category` | `all` | Category filter within the bench; comma-separated list or `all`. For `qualitative` the categories are `coding`, `humanities`, `math`, `multilingual`, `qa`, `rag`, `reasoning`, `roleplay`, `stem`, `summarization`, `writing`. For the `throughput_{ISL}` splits they are `high_entropy`, `low_entropy`, `mixed`. |
|
||||
| `--osl` | `4096` | Output sequence length, mapped to `max_tokens`. |
|
||||
| `--extra-inputs` | `{"temperature":0}` | Extra request fields as a JSON object. |
|
||||
| `--concurrency` | `1` | Concurrent client requests; usually match `--np`. |
|
||||
| `--limit` | none | Max samples per category (handy for smoke tests). |
|
||||
| `--timeout` | `600` | Per-request timeout in seconds. |
|
||||
| `--output` | none | Save raw per-request results and the summary to JSON. |
|
||||
|
||||
A few common ones:
|
||||
|
||||
- `--category all` runs every category in the bench.
|
||||
- `--category coding,math` runs just those two.
|
||||
- `--bench throughput_8k` runs a fixed-input-length throughput split.
|
||||
- `--limit 8` keeps at most 8 samples per category, which is enough for a quick check.
|
||||
|
||||
The `throughput_{ISL}` splits use fixed input lengths (1k - 32k), so they are handy for long-context testing and for comparing different `llama-server` batching settings (e.g. sweeping `-ub` / `--ubatch-size`) on prompts of a known size. Make sure the server `-c` is large enough for the chosen split. When raising `-ub`, also raise `-b` to at least the same value, since the physical ubatch cannot exceed the logical batch.
|
||||
|
||||
When `--output` is given, the JSON file holds the run `config`, the `selected_samples` / `completed_samples` / `failed_samples` counts, the per-category `summary` rows, and the per-sample `results`.
|
||||
|
||||
## Metrics
|
||||
|
||||
The summary prints one row per category plus an `overall` row:
|
||||
|
||||
- `samples` - how many samples finished successfully.
|
||||
- `avg_prompt_t/s` - prefill throughput from llama.cpp (`timings.prompt_per_second`), averaged over the category's samples.
|
||||
- `avg_pred_t/s` - decode throughput from llama.cpp (`timings.predicted_per_second`), averaged over the category's samples.
|
||||
- `avg_latency` - average end-to-end request latency seen by the client.
|
||||
- `accept_rate` - `accepted / draft_n` over the category, or `n/a` if nothing was drafted (`draft_n == 0`).
|
||||
|
||||
## Baseline vs speculative decoding
|
||||
|
||||
Save a run from each server with `--output`, then diff the two JSON files with `speed_bench_compare.py`.
|
||||
|
||||
First, start a plain `llama-server` (no speculative decoding) and save a baseline:
|
||||
|
||||
```bash
|
||||
python tools/server/bench/speed-bench/speed_bench.py \
|
||||
--url localhost:8080 \
|
||||
--bench qualitative \
|
||||
--category all \
|
||||
--osl 1024 \
|
||||
--concurrency 1 \
|
||||
--output baseline.json
|
||||
```
|
||||
|
||||
Then restart `llama-server` with speculative decoding enabled and save another run:
|
||||
|
||||
```bash
|
||||
python tools/server/bench/speed-bench/speed_bench.py \
|
||||
--url localhost:8080 \
|
||||
--bench qualitative \
|
||||
--category all \
|
||||
--osl 1024 \
|
||||
--concurrency 1 \
|
||||
--output spec.json
|
||||
```
|
||||
|
||||
Finally compare the two:
|
||||
|
||||
```bash
|
||||
python tools/server/bench/speed-bench/speed_bench_compare.py \
|
||||
--baseline baseline.json \
|
||||
--speculative spec.json
|
||||
```
|
||||
|
||||
The comparison table adds:
|
||||
|
||||
- `decode_speedup = spec_avg_pred_t/s / base_avg_pred_t/s`
|
||||
- `latency_speedup = base_avg_latency / spec_avg_latency`
|
||||
|
||||
Keep `--bench`, `--category`, `--osl`, and `--limit` the same across both runs, otherwise they won't be using the same prompts.
|
||||
@@ -0,0 +1,3 @@
|
||||
datasets
|
||||
requests
|
||||
tqdm
|
||||
@@ -0,0 +1,432 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from datasets import get_dataset_config_names, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# Dataset repository passed to `get_dataset_config_names` / `load_dataset`;
# each of its configs is a bench selectable with --bench.
DATASET_REPO = "nvidia/SPEED-Bench"
|
||||
|
||||
@dataclass
class Sample:
    """One selected benchmark prompt, reduced to the fields the client uses."""

    # Stripped `question_id` of the dataset row.
    id: str
    # Stripped category label; used for filtering and for grouping the summary.
    category: str
    # Non-empty user turns, sent in order as a single conversation.
    turns: list[str]
|
||||
|
||||
|
||||
@dataclass
class RequestResult:
    """Outcome of running one sample (all of its turns) against the server."""

    # Identity copied from the sample.
    id: str
    category: str
    # True when every turn completed; False when an exception was caught.
    ok: bool
    # Number of turns in the sample.
    turns: int
    # Client-side wall time summed over the turn requests, in seconds.
    latency_s: float
    # Token counts summed over turns (all zero on failure).
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Finish reason reported for the last completed turn, if any.
    finish_reason: str | None
    # Draft token counters summed from the response `timings` objects.
    draft_n: int
    draft_n_accepted: int
    # Timing totals from `timings`; None when nothing was reported.
    prompt_ms: float | None
    predicted_ms: float | None
    # Throughput: server-reported for single-turn samples, recomputed from the
    # token and time totals for multi-turn samples; None when unavailable.
    prompt_per_second: float | None
    predicted_per_second: float | None
    # Stringified exception on failure, otherwise None.
    error: str | None
|
||||
|
||||
|
||||
def normalize_base_url(url: str) -> str:
    """Turn a user-supplied server address into a base URL ending in ``/v1``.

    The scheme defaults to ``http://`` and ``/v1`` is appended unless the path
    already ends with it. Trailing slashes are dropped.

    Raises:
        ValueError: If the address is empty or has no scheme/host after
            normalization.
    """
    candidate = url.strip().rstrip("/")
    if candidate == "":
        raise ValueError("--url cannot be empty")

    has_scheme = "://" in candidate
    candidate = candidate if has_scheme else "http://" + candidate

    parts = urlparse(candidate)
    if not (parts.scheme and parts.netloc):
        raise ValueError(f"invalid --url: {candidate}")

    path = parts.path.rstrip("/")
    if not path.endswith("/v1"):
        candidate += "/v1"
    return candidate.rstrip("/")
|
||||
|
||||
|
||||
def parse_extra_inputs(value: str) -> dict[str, Any]:
    """Decode the ``--extra-inputs`` argument, which must be a JSON object.

    Raises:
        ValueError: If the text is not valid JSON (``json.JSONDecodeError``)
            or decodes to something other than an object.
    """
    decoded = json.loads(value)
    if isinstance(decoded, dict):
        return decoded
    raise ValueError("--extra-inputs must be a JSON object")
|
||||
|
||||
|
||||
def extract_turns(row: dict[str, Any]) -> list[str]:
    """Return the non-blank turns of a dataset row as stripped strings.

    Falsy entries and entries that are blank after stripping are dropped.

    Raises:
        ValueError: If ``row["turns"]`` is not a list or no usable turn is left.
    """
    failure = "missing or empty turns"
    raw_turns = row.get("turns")
    if not isinstance(raw_turns, list):
        raise ValueError(failure)

    cleaned: list[str] = []
    for item in raw_turns:
        if not item:
            continue
        text = str(item).strip()
        if text:
            cleaned.append(text)

    if not cleaned:
        raise ValueError(failure)
    return cleaned
|
||||
|
||||
|
||||
def load_samples(args: argparse.Namespace) -> list[Sample]:
    """Load the requested bench and select the samples to run.

    Reads ``args.bench``, ``args.category`` (``"all"`` or a comma-separated
    list) and ``args.limit`` (optional per-category cap).

    Returns:
        The selected samples in dataset order.

    Raises:
        ValueError: If the bench or any requested category is unknown, or the
            category list is empty.
        RuntimeError: If the filters select no samples at all.
    """
    bench_names = get_dataset_config_names(DATASET_REPO)
    if args.bench not in bench_names:
        raise ValueError(
            f"unknown --bench {args.bench!r}; available benches: {', '.join(bench_names)}"
        )

    dataset = load_dataset(DATASET_REPO, name=args.bench, split="test")
    # dict.fromkeys de-duplicates while keeping first-seen order.
    categories = list(dict.fromkeys(str(category) for category in dataset["category"]))
    requested_categories = None
    if args.category != "all":
        requested_list = [category.strip() for category in args.category.split(",") if category.strip()]
        if not requested_list:
            raise ValueError(
                f"--category must be 'all' or a comma-separated list; available categories: {', '.join(categories)}"
            )
        requested_categories = set(requested_list)
        unknown_categories = [category for category in requested_list if category not in categories]
        if unknown_categories:
            unknown = ", ".join(unknown_categories)
            raise ValueError(
                f"unknown --category {unknown!r} for bench {args.bench!r}; "
                f"available categories: all, {', '.join(categories)}"
            )

    samples: list[Sample] = []
    samples_per_category: dict[str, int] = {}
    skipped = 0
    for row_raw in dataset:
        row = dict(row_raw)
        category_raw = row.get("category")
        if not isinstance(category_raw, str) or not category_raw.strip():
            skipped += 1
            continue
        category = category_raw.strip()
        if requested_categories is not None and category not in requested_categories:
            continue
        # The cap is checked before validation, so rows past the limit are
        # neither selected nor counted as skipped.
        if args.limit is not None and samples_per_category.get(category, 0) >= args.limit:
            continue

        try:
            turns = extract_turns(row)
        except ValueError:
            skipped += 1
            continue
        question_id = row.get("question_id")
        if not isinstance(question_id, str) or not question_id.strip():
            skipped += 1
            continue
        samples.append(Sample(id=question_id.strip(), category=category, turns=turns))
        samples_per_category[category] = samples_per_category.get(category, 0) + 1

    if not samples:
        raise RuntimeError(f"no samples selected from bench={args.bench} category={args.category}")

    if skipped:
        # The counter covers every rejection reason above, not only bad turns.
        print(f"speed_bench: skipped {skipped} rows without a usable category, turns, or question_id")
    return samples
|
||||
|
||||
|
||||
def parse_completion_response(data: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], str | None, str]:
    """Pull the fields the benchmark needs out of a completion response.

    Returns:
        ``(usage, timings, finish_reason, content)``. ``usage`` and ``timings``
        default to empty dicts. The text comes from the first choice's
        ``message.content``, falling back to its ``text`` field, and is ``""``
        when neither is a string.
    """
    usage = data.get("usage") or {}
    timings = data.get("timings") or {}

    choices = data.get("choices")
    first = choices[0] if isinstance(choices, list) and choices else None
    if not isinstance(first, dict):
        return usage, timings, None, ""

    message = first.get("message")
    chat_text = message.get("content") if isinstance(message, dict) else None
    if isinstance(chat_text, str):
        text = chat_text
    else:
        plain_text = first.get("text")
        text = plain_text if isinstance(plain_text, str) else ""
    return usage, timings, first.get("finish_reason"), text
|
||||
|
||||
|
||||
def run_request(
    endpoint: str,
    model: str | None,
    messages: list[dict[str, str]],
    osl: int,
    extra_inputs: dict[str, Any],
    timeout: float,
) -> tuple[dict[str, Any], float]:
    """POST one non-streaming chat completion and time it.

    Args:
        endpoint: Full chat-completions URL.
        model: Optional ``model`` field; omitted when falsy.
        messages: Conversation so far, sent as-is.
        osl: Output sequence length, sent as ``max_tokens``.
        extra_inputs: Additional request fields merged into the payload.
        timeout: Timeout in seconds, passed to ``requests.post``.

    Returns:
        ``(decoded JSON body, client-side latency in seconds)``.

    Raises:
        RuntimeError: If the server answers with a status other than 200.
    """
    payload: dict[str, Any] = {}
    if model:
        payload["model"] = model
    payload.update(extra_inputs)
    # These fields are owned by the benchmark and are applied last so that
    # extra_inputs cannot clobber them: a streamed body cannot be decoded by
    # response.json(), and replaced messages would invalidate every sample.
    payload["messages"] = messages
    payload["max_tokens"] = osl
    payload["stream"] = False

    start = time.perf_counter()
    response = requests.post(endpoint, json=payload, timeout=timeout)
    latency_s = time.perf_counter() - start
    if response.status_code != 200:
        # Keep the error on a single line and bounded in size.
        body = response.text[:500].replace("\n", "\\n")
        raise RuntimeError(f"HTTP {response.status_code}: {body}")
    return response.json(), latency_s
|
||||
|
||||
|
||||
def run_one(
    sample: Sample,
    endpoint: str,
    model: str | None,
    osl: int,
    extra_inputs: dict[str, Any],
    timeout: float,
) -> RequestResult:
    """Run every turn of one sample as a growing conversation.

    Each user turn is appended to the message list, sent to the server, and
    followed by the assistant reply. Token counts and timings are summed over
    the turns. Any exception yields a failed result (``ok=False``) carrying
    the latency accumulated so far and the error text; it is never re-raised.
    """
    turns = sample.turns
    single_turn = len(turns) == 1
    multi_turn = len(turns) > 1

    conversation: list[dict[str, str]] = []
    elapsed_s = 0.0
    n_prompt = 0
    n_completion = 0
    n_total = 0
    n_draft = 0
    n_draft_accepted = 0
    prompt_ms = 0.0
    predicted_ms = 0.0
    prompt_speed = None
    predicted_speed = None
    last_finish_reason: str | None = None

    try:
        for user_text in turns:
            conversation.append({"role": "user", "content": user_text})
            data, turn_latency = run_request(endpoint, model, conversation, osl, extra_inputs, timeout)
            elapsed_s += turn_latency
            usage, timings, last_finish_reason, reply = parse_completion_response(data)

            # Prefer the `usage` counters and fall back to `timings`.
            turn_prompt = int(usage.get("prompt_tokens") or timings.get("prompt_n") or 0)
            turn_completion = int(usage.get("completion_tokens") or timings.get("predicted_n") or 0)
            turn_total = int(usage.get("total_tokens") or (turn_prompt + turn_completion))
            n_prompt += turn_prompt
            n_completion += turn_completion
            n_total += turn_total
            n_draft += int(timings.get("draft_n") or 0)
            n_draft_accepted += int(timings.get("draft_n_accepted") or 0)
            prompt_ms += float(timings.get("prompt_ms") or 0)
            predicted_ms += float(timings.get("predicted_ms") or 0)

            if single_turn:
                # Single-turn samples take the server-reported rates verbatim.
                reported_prompt = timings.get("prompt_per_second")
                if isinstance(reported_prompt, (int, float)):
                    prompt_speed = float(reported_prompt)
                reported_predicted = timings.get("predicted_per_second")
                if isinstance(reported_predicted, (int, float)):
                    predicted_speed = float(reported_predicted)

            conversation.append({"role": "assistant", "content": reply})

        if n_total == 0:
            n_total = n_prompt + n_completion
        if multi_turn:
            # Multi-turn samples derive the rates from the summed totals.
            prompt_speed = (n_prompt / (prompt_ms / 1000)) if prompt_ms > 0 else None
            predicted_speed = (n_completion / (predicted_ms / 1000)) if predicted_ms > 0 else None

        return RequestResult(
            id=sample.id,
            category=sample.category,
            ok=True,
            turns=len(turns),
            latency_s=elapsed_s,
            prompt_tokens=n_prompt,
            completion_tokens=n_completion,
            total_tokens=n_total,
            finish_reason=last_finish_reason,
            draft_n=n_draft,
            draft_n_accepted=n_draft_accepted,
            prompt_ms=prompt_ms if prompt_ms > 0 else None,
            predicted_ms=predicted_ms if predicted_ms > 0 else None,
            prompt_per_second=prompt_speed,
            predicted_per_second=predicted_speed,
            error=None,
        )
    except Exception as exc:
        # Report the failure as data so one bad sample cannot abort the run.
        return RequestResult(
            id=sample.id,
            category=sample.category,
            ok=False,
            turns=len(turns),
            latency_s=elapsed_s,
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            finish_reason=None,
            draft_n=0,
            draft_n_accepted=0,
            prompt_ms=None,
            predicted_ms=None,
            prompt_per_second=None,
            predicted_per_second=None,
            error=str(exc),
        )
|
||||
|
||||
|
||||
def summarize_group(category: str, results: list[RequestResult]) -> dict[str, Any]:
    """Aggregate the results of one category (or of the whole run) into a row.

    Only successful results contribute to the averages and counters; failures
    are reported through the ``failed`` count. Averages are ``None`` when no
    value is available, and ``accept_rate`` is ``None`` when nothing was
    drafted.
    """
    succeeded = [item for item in results if item.ok]

    def _mean_or_none(values: list[float]) -> float | None:
        return statistics.mean(values) if values else None

    prompt_speeds = [item.prompt_per_second for item in succeeded if item.prompt_per_second is not None]
    decode_speeds = [item.predicted_per_second for item in succeeded if item.predicted_per_second is not None]
    drafted = sum(item.draft_n for item in succeeded)
    accepted = sum(item.draft_n_accepted for item in succeeded)

    row: dict[str, Any] = {"category": category}
    row["requests"] = len(succeeded)
    row["turns"] = sum(item.turns for item in succeeded)
    row["failed"] = len(results) - len(succeeded)
    row["avg_prompt_t_s"] = _mean_or_none(prompt_speeds)
    row["avg_pred_t_s"] = _mean_or_none(decode_speeds)
    row["avg_latency"] = _mean_or_none([item.latency_s for item in succeeded])
    row["draft_n"] = drafted
    row["accepted"] = accepted
    row["accept_rate"] = (accepted / drafted) if drafted > 0 else None
    return row
|
||||
|
||||
|
||||
def fmt_value(value: Any, kind: str = "") -> str:
    """Render one table cell.

    ``None`` becomes ``"n/a"``. ``kind`` selects the format: ``int``, ``rate``
    (4 decimals), ``seconds`` (3 decimals + ``s``), ``speed`` (2 decimals) or
    ``speedup`` (2 decimals + ``x``); any other kind falls back to ``str``.
    """
    if value is None:
        return "n/a"
    if kind == "int":
        return str(int(value))

    float_templates = {
        "rate": "{:.4f}",
        "seconds": "{:.3f}s",
        "speed": "{:.2f}",
        "speedup": "{:.2f}x",
    }
    template = float_templates.get(kind)
    if template is None:
        return str(value)
    return template.format(float(value))
|
||||
|
||||
|
||||
def print_table(rows: list[dict[str, Any]]) -> None:
    """Print the per-category summary rows using the benchmark's column layout."""
    headers = ("category", "samples", "avg_prompt_t/s", "avg_pred_t/s", "avg_latency", "accept_rate")
    keys = ("category", "requests", "avg_prompt_t_s", "avg_pred_t_s", "avg_latency", "accept_rate")
    kinds = ("", "int", "speed", "speed", "seconds", "rate")
    # Each column is a (header, row key, format kind) triple.
    print_rows(rows, list(zip(headers, keys, kinds)))
|
||||
|
||||
|
||||
def print_rows(rows: list[dict[str, Any]], columns: list[tuple[str, str, str]]) -> None:
    """Print ``rows`` as a left-aligned text table.

    Args:
        rows: One dict per table row; missing keys render as ``n/a``.
        columns: ``(header, row key, format kind)`` triples, in display order.
    """
    titles = [title for title, _, _ in columns]
    table = [[fmt_value(row.get(key), kind) for _, key, kind in columns] for row in rows]

    # A column is as wide as its header or its widest cell, whichever is larger.
    widths = [
        max(len(line[position]) for line in [titles, *table])
        for position in range(len(titles))
    ]

    def _render(cells: list[str]) -> str:
        return " ".join(cell.ljust(width) for cell, width in zip(cells, widths))

    print(_render(titles))
    print(" ".join("-" * width for width in widths))
    for cells in table:
        print(_render(cells))
|
||||
|
||||
|
||||
def save_output(path: str, args: argparse.Namespace, samples: list[Sample], results: list[RequestResult], summary: list[dict[str, Any]]) -> None:
    """Write the run configuration, counts, summary and raw results as JSON.

    The file is UTF-8, indented by two spaces, with sorted keys.
    """
    config_fields = ("url", "model", "bench", "category", "osl", "concurrency", "extra_inputs")
    completed = sum(1 for item in results if item.ok)

    document = {
        "config": {name: getattr(args, name) for name in config_fields},
        "selected_samples": len(samples),
        "completed_samples": completed,
        "failed_samples": len(results) - completed,
        "summary": summary,
        "results": [asdict(item) for item in results],
    }
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(document, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Run the benchmark from the command line.

    Args:
        argv: Argument list; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` when every sample succeeded, ``1`` when at least one sample
        failed, ``2`` when setup (argument validation, URL normalization or
        dataset loading) failed.
    """
    parser = argparse.ArgumentParser(description="Run SPEED-Bench against an OpenAI-compatible llama-server.")
    parser.add_argument("--url", default="localhost:8080", help="Server URL, for example localhost:8080 or http://localhost:8080/v1")
    parser.add_argument("--model", default=None, help="Optional model name to send in OpenAI requests")
    parser.add_argument("--bench", default="qualitative", help="SPEED-Bench config to run, for example qualitative or throughput_1k")
    parser.add_argument("--category", default="all", help="Category to run within the selected bench; use all for no category filter")
    parser.add_argument("--osl", type=int, default=4096, help="Output sequence length, mapped to max_tokens")
    parser.add_argument("--extra-inputs", default='{"temperature":0}', help="Extra request fields as a JSON object")
    parser.add_argument("--concurrency", type=int, default=1, help="Concurrent client requests; usually match llama-server --np")
    parser.add_argument("--limit", type=int, default=None, help="Optional sample limit per category for smoke tests")
    parser.add_argument("--timeout", type=float, default=600, help="Per-request timeout in seconds")
    parser.add_argument("--output", default=None, help="Optional path to save raw results JSON")
    args = parser.parse_args(argv)
    try:
        # ThreadPoolExecutor rejects max_workers < 1 with a ValueError; catch
        # it here so it is reported as a setup failure, not a traceback.
        if args.concurrency < 1:
            raise ValueError("--concurrency must be at least 1")
        base_url = normalize_base_url(args.url)
        endpoint = base_url + "/chat/completions"
        extra_inputs = parse_extra_inputs(args.extra_inputs)
        # Store the decoded object so save_output records it as JSON, not text.
        args.extra_inputs = extra_inputs
        samples = load_samples(args)
    except Exception as exc:
        print(f"speed_bench: setup failed: {exc}", file=sys.stderr)
        return 2

    print(f"speed_bench: loaded {len(samples)} samples from bench={args.bench} category={args.category}")

    results: list[RequestResult] = []
    started = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.concurrency) as executor:
        futures = [
            executor.submit(run_one, sample, endpoint, args.model, args.osl, extra_inputs, args.timeout)
            for sample in samples
        ]
        # Results are collected in completion order, not submission order.
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="speed_bench", unit="sample"):
            result = future.result()
            results.append(result)

    elapsed = time.perf_counter() - started
    # Categories in first-seen sample order, followed by an "overall" row.
    categories = list(dict.fromkeys(sample.category for sample in samples))
    summary = [
        summarize_group(category, [result for result in results if result.category == category])
        for category in categories
    ]
    summary.append(summarize_group("overall", results))
    print()
    print(f"Summary (elapsed={elapsed:.2f}s)")
    print_table(summary)

    if args.output:
        save_output(args.output, args, samples, results, summary)
        print(f"\nspeed_bench: wrote {args.output}")

    failed = sum(1 for result in results if not result.ok)
    if failed:
        print(f"\nspeed_bench: {failed} samples failed", file=sys.stderr)
        first_error = next((result.error for result in results if result.error), None)
        if first_error:
            print(f"first error: {first_error}", file=sys.stderr)
        return 1
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from speed_bench import fmt_value, print_rows
|
||||
|
||||
|
||||
def load_summary(path: str) -> list[dict[str, Any]]:
    """Read the ``summary`` rows from a results file written by speed_bench.py.

    Raises:
        ValueError: If the file has no ``summary`` list.
    """
    with open(path, "r", encoding="utf-8") as handle:
        document = json.load(handle)
    rows = document.get("summary")
    if isinstance(rows, list):
        return rows
    raise ValueError(f"{path} does not contain a summary list")
|
||||
|
||||
|
||||
def compare_rows(baseline: list[dict[str, Any]], speculative: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Pair speculative rows with baseline rows of the same category.

    Rows are emitted in the order of ``speculative``. Categories without a
    baseline counterpart are dropped, as are malformed rows (not a dict, or
    without a ``category`` key), so a damaged results file cannot crash the
    comparison.

    Returns:
        One dict per matched category with both decode speeds, both latencies,
        ``decode_speedup`` (spec / base), ``latency_speedup`` (base / spec) and
        the speculative run's ``accept_rate``. A speedup is ``None`` when
        either operand is missing or zero.
    """

    def _usable(row: Any) -> bool:
        return isinstance(row, dict) and "category" in row

    def _ratio(numerator: Any, denominator: Any) -> float | None:
        # Truthiness covers both None and 0, so this never divides by zero.
        return (numerator / denominator) if numerator and denominator else None

    baseline_by_category = {row["category"]: row for row in baseline if _usable(row)}
    comparisons = []
    for row in speculative:
        if not _usable(row):
            continue
        base = baseline_by_category.get(row["category"])
        if not base:
            continue
        base_speed = base.get("avg_pred_t_s")
        spec_speed = row.get("avg_pred_t_s")
        base_latency = base.get("avg_latency")
        spec_latency = row.get("avg_latency")
        comparisons.append(
            {
                "category": row["category"],
                "base_avg_pred_t_s": base_speed,
                "spec_avg_pred_t_s": spec_speed,
                "decode_speedup": _ratio(spec_speed, base_speed),
                "base_avg_latency": base_latency,
                "spec_avg_latency": spec_latency,
                "latency_speedup": _ratio(base_latency, spec_latency),
                "accept_rate": row.get("accept_rate"),
            }
        )
    return comparisons
|
||||
|
||||
|
||||
def print_comparison(rows: list[dict[str, Any]]) -> None:
    """Print the baseline-vs-speculative table, or a notice when it is empty."""
    if rows:
        headers = (
            "category", "base_avg_pred_t/s", "spec_avg_pred_t/s", "decode_speedup",
            "base_avg_latency", "spec_avg_latency", "latency_speedup", "accept_rate",
        )
        keys = (
            "category", "base_avg_pred_t_s", "spec_avg_pred_t_s", "decode_speedup",
            "base_avg_latency", "spec_avg_latency", "latency_speedup", "accept_rate",
        )
        kinds = ("", "speed", "speed", "speedup", "seconds", "seconds", "speedup", "rate")
        # Each column is a (header, row key, format kind) triple.
        print_rows(rows, list(zip(headers, keys, kinds)))
    else:
        print("No overlapping categories found for comparison.")
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Compare two result files from the command line.

    Returns:
        ``0`` after printing the comparison, ``2`` when either input file
        cannot be loaded.
    """
    cli = argparse.ArgumentParser(description="Compare two SPEED-Bench runs (baseline vs speculative).")
    cli.add_argument("--baseline", required=True, help="Baseline results JSON produced by speed_bench.py --output")
    cli.add_argument("--speculative", required=True, help="Speculative decoding results JSON produced by speed_bench.py --output")
    options = cli.parse_args(argv)

    try:
        baseline_rows = load_summary(options.baseline)
        speculative_rows = load_summary(options.speculative)
    except Exception as exc:
        print(f"speed_bench_compare: failed to load inputs: {exc}", file=sys.stderr)
        return 2

    # Build the table before printing the heading, so a failure here does not
    # leave a dangling header on stdout.
    table = compare_rows(baseline_rows, speculative_rows)
    print(f"Comparison: baseline={options.baseline} speculative={options.speculative}")
    print_comparison(table)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user