Compare commits

47 Commits

Author SHA1 Message Date
Xuan Son Nguyen a432e6f863 use destructor instead 2026-06-23 22:57:20 +02:00
Xuan Son Nguyen 5d67f69f59 remove outdated comment 2026-06-23 22:49:40 +02:00
Xuan-Son Nguyen beef5cf077 Apply suggestions from code review
Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-06-23 22:48:04 +02:00
Xuan Son Nguyen b093e46873 case: router with only one model 2026-06-23 16:47:30 +02:00
Xuan Son Nguyen 1401fc3ca7 cli support router mode
Co-authored-by: Piotr Wilkin <ilintar@gmail.com>
2026-06-23 16:43:58 +02:00
Xuan Son Nguyen 85c58bbcd0 remote server ok 2026-06-23 16:19:28 +02:00
Xuan Son Nguyen 19296c1735 working 2026-06-23 16:09:09 +02:00
Xuan Son Nguyen 90c111bf98 Merge branch 'master' into xsn/cli_http_based 2026-06-23 13:29:22 +02:00
Xuan-Son Nguyen 75ad0b23ed server: fix remote preset handling, add test (#24938)
* server: add test for remote preset

* fix remote preset handling

* fix

* fix test
2026-06-23 13:28:34 +02:00
Xuan Son Nguyen f7421eabe8 wip 2026-06-23 13:28:14 +02:00
Xuan Son Nguyen 59797670dc cli: move to HTTP-based implementation 2026-06-23 13:14:28 +02:00
Wyatt Caldwell c926ad0985 vulkan: link ggml-cpu when GGML_VULKAN_CHECK_RESULTS / RUN_TESTS are enabled (#24444)
The result-checking and test debug paths in ggml-vulkan.cpp call ggml_graph_compute_with_ctx() to compute a CPU reference graph, but that symbol is defined in ggml-cpu, which ggml-vulkan does not link. Enabling -DGGML_VULKAN_CHECK_RESULTS=ON (or -DGGML_VULKAN_RUN_TESTS=ON) therefore fails to link with an unresolved external (e.g. LNK2019 on MSVC, undefined reference on GCC/Clang). This regressed after ggml-cpu was split into its own library. Link ggml-cpu under those two options so the debug builds link again.

Signed-off-by: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com>
2026-06-23 12:55:46 +02:00
Gabe Goodhart a3900a6694 model: Granite Speech Plus (#24818)
* feat: Add conversion support for Granite Speech Plus

Branch: GraniteSpeechPlus
AI-usage: full (Bob, OpenCode + Qwen3.6-35b)
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Extend granite_speech to support plus multi-layer concatenation

Branch: GraniteSpeechPlus
AI-usage: draft (Bob, OpenCode + Qwen3.6-35b)
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(conversion): Fix plural naming for feature_layers for audio

Branch: GraniteSpeechPlus
AI-usage: none
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(mtmd): Align feature_layer usage and naming everywhere

Branch: GraniteSpeechPlus
AI-usage: none
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Use fstring for log

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-06-23 12:03:31 +02:00
Masashi Yoshimura 7c908502ea ggml-webgpu: improve MTP inference by using mat-vec path for small batches (#24811)
* ggml-webgpu: improve small batches decoding

* Add barrier to the NUM_COLS loop in mul-mat-vec
2026-06-23 17:13:55 +09:00
Masashi Yoshimura 035cd8f9a6 codeowners: add yomaytk to ggml-webgpu (#24930) 2026-06-23 15:19:34 +09:00
Aldehir Rojas 73618f27a8 server: improve user message detection and create checkpoints at every user message (#24176)
* server : improve message span logic

* cont : cast size_t to int32_t in comparisons

* server : create checkpoints before every user msg

* chat : remove \n in gemma4 delimiters

* chat : merge msg delimiter structs into one

* cont : reword comment

* cont : initialize tokens in delimiter

* cont : add server_tokens::get_raw_tokens() for mtmd

* cont : move message finding to server_tokens and skip mtmd tokens

* cont : update cohere2moe parser

* cont : increase min-step to 8192 and always produce a chkpt for last user message
2026-06-23 08:27:28 +03:00
Shawn Gu 23ee8797e1 opencl: q8_0 gemv precision improvement (#24923) 2026-06-22 22:25:21 -07:00
Matt Thompson dec5ca5577 server : Add id to tool call responses api (#24882) 2026-06-22 23:03:12 +02:00
Mahdiou Diallo 9c0ac887f3 ui: Prioritize favorite models in model selection (#24766)
Updated model selection prioritization to include favorite models.
2026-06-22 21:00:21 +02:00
Xuan-Son Nguyen 721354fbdf server: (router) move model downloading to dedicated process (#24834)
* server: real-time model load progress tracking via /models/sse

* update docs

* server: move model download to child process

* rm unused

* fix most problems

* clean up

* nit fixes

* fix test case

* do not detach() thread

* shorter MODEL_DOWNLOAD_TIMEOUT in test

* throttle
2026-06-22 18:24:04 +02:00
Xuan-Son Nguyen 6ee0f65793 server: refactor/generalize input file schema (#24299)
* server: refactor/generalize input file schema

* wire up input_video, accept raw base64

* nits

* nits (2)

* fix windows
2026-06-22 16:42:47 +02:00
Pascal 099b579acb ui: model status and load progress via /models/sse feed (#24878)
* ui: model status and load progress via /models/sse feed

* ui: centralize SSE wire-format delimiters into shared constants for the chat and /models/sse parsers

* ui: type /models/sse event names as a ServerModelsSseEventType enum

Address review from allozaur
2026-06-22 15:55:30 +02:00
Neo Zhang f8cc15f163 [SYCL] support bf16 on bin_bcast OP and unary OPs (#24838)
* support bf16 on bin_bcast OP and unary OPs

* support Intel compilers older than 2026.0
2026-06-22 14:09:02 +03:00
Tim Neumann 37957e8531 sampling : remove unconditional softmax+sort in top-n-sigma sampler (#22645) 2026-06-22 14:08:32 +03:00
Pascal d0f9d2e5ac server: fix edit_file crash on append at end of file (line_start -1) (#24893)
line_start -1 normalized to n+1, so append inserted at lines.begin() + n + 1,
one past end() -> heap-buffer-overflow in vector::_M_range_insert.

Normalize -1 to n (insert at end()), restrict -1 to append mode and reject it
for replace/delete instead of silently clobbering the last line. Parenthesize
the insert offset so empty-file append computes the position as int first,
avoiding a transient begin() - 1 on a null vector data pointer.
2026-06-22 10:55:28 +02:00
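The normalization described in the commit message above can be sketched as follows. This is a minimal illustration under stated assumptions, not the server's actual code: `edit_mode`, `resolve_line_start`, and `append_lines` are hypothetical names, and the sketch assumes that append mode inserts after the 1-based line `line_start`, so that valid values run from 0 to n.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical edit modes; the real tool's names may differ.
enum class edit_mode { append, replace, remove };

// Resolve a user-supplied line_start against a file of n lines.
// Assumption: in append mode, new lines go after 1-based line `line_start`,
// so the valid range is [0, n] and -1 means "after the last line".
int resolve_line_start(int line_start, int n, edit_mode mode) {
    if (line_start == -1) {
        if (mode != edit_mode::append) {
            // reject instead of silently touching the last line
            throw std::invalid_argument("line_start -1 is only valid in append mode");
        }
        return n; // insert at end(), not one past it
    }
    const int lo = mode == edit_mode::append ? 0 : 1;
    if (line_start < lo || line_start > n) {
        throw std::out_of_range("line_start out of range");
    }
    return line_start;
}

void append_lines(std::vector<std::string> & lines, int line_start,
                  const std::vector<std::string> & extra) {
    const int pos = resolve_line_start(line_start, (int) lines.size(), edit_mode::append);
    assert(pos >= 0 && (std::size_t) pos <= lines.size());
    // the offset is a plain int before it is added to begin()
    lines.insert(lines.begin() + pos, extra.begin(), extra.end());
}
```

With this shape, `append_lines(lines, -1, ...)` on an empty vector inserts at `begin() + 0`, which is a valid position even when the vector owns no storage.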
aafsmarak 0ef6f06d55 docs/android.md: Add dependency libandroid-spawn for building in termux (#21812)
Fixes https://github.com/ggml-org/llama.cpp/issues/18615
2026-06-22 05:48:31 +02:00
Aldehir Rojas 52b3df0023 common/peg : implement ac parser for stricter grammar generation (#24869)
* common/peg : implement ac parser

* cont : extract functions

* cont : tidy up

* cont : remove a test

* cont : move ac() def
2026-06-21 16:20:58 -05:00
Xuan-Son Nguyen 7c082bc417 server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list

* improve

* nits

* nits 2
2026-06-21 17:36:52 +02:00
Xuan-Son Nguyen bddfd2b113 server: refactor batch construction (#24843)
* server: refactor batch construction

* wip

* wip 2

* wip 3

* wip 4

* add abort_all_slots

* handle batch full more carefully

* fix assert

* rm debug log

* small nits

* (debug) add timings

* debug: force llama_synchronize for accurate timings

* address comments

* disable DEBUG_TIMINGS
2026-06-21 14:16:11 +02:00
Xuan-Son Nguyen 0d135df48c mtmd: fix mtmd_get_memory_usage (#24867) 2026-06-21 14:12:15 +02:00
Sigbjørn Skjæret bf533823cd jinja : implement call statement (#24847)
* implement call statement

* undo unintended change

* de-lambda

* simplify

* move caller context inside function handler
2026-06-21 14:04:52 +02:00
Xuan-Son Nguyen 2f89acc2bc mtmd: add load progress callback (#24865) 2026-06-21 13:40:52 +02:00
Xuan-Son Nguyen bfa3219177 server: add "verbose" field to schema (#24864) 2026-06-21 13:03:14 +02:00
Xuan-Son Nguyen d6d899580d server: real-time model load progress tracking via /models/sse (#24828)
* server: real-time model load progress tracking via /models/sse

* update docs

* add mutex for notify_to_router

* correct docs
2026-06-21 11:58:14 +02:00
Georgi Gerganov 8a118ee86c minor : clean-up whitespaces (#24862)
[no ci]
2026-06-21 11:37:12 +03:00
YiChen Lv d789527482 spec : Support Step3.5/3.7 flash mtp3 (#24340)
* add mtp_layer_offset + include nextn flags in graph reuse

* add llama_set_mtp_layer_offset + llama_model_n_nextn_layer API

* offset head select + require all MTP blocks

* speculative multi-head process()

* speculative multi-head draft()

* gather outputs via inp_out_ids

* cleanup

* fix core

* minor cleanup

* merged draft_multi_head into draft()

* mtp rename nextn

* Apply suggestions from code review

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* clean-up comments

* fix for multi seq

* apply suggestions && chain-heads comment

* add a reference for chain_heads discussion

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-06-21 11:33:18 +03:00
Aldehir Rojas 063d9c156e common/peg : refactor until gbnf grammar generation (#24839)
* common/peg : refactor until gbnf grammar into an ac automaton

* cont : add a test with multiple strings

* cont : pad state with 0s so rules line up

* cont : clean up comments

* cont : use set everywhere

* cont : inline state num string padding

* cont : add a ref to PR

* cont : fix regression in server-tools.cpp
2026-06-20 21:15:06 -05:00
Aldehir Rojas c57607016a common/json-schema-to-grammar : align spacing rules with parsers (#24835) 2026-06-20 17:43:04 -05:00
Guanhuai Zhang 4a80943174 fix(hexagon): use padded stride for ssm-conv weights (#24470) 2026-06-20 14:58:49 -07:00
Adrien Gallouët 84de01a1f1 llama : use LLM_KV for quantization_version & file_type (#24802)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-20 20:07:01 +02:00
Xuan-Son Nguyen 75f460ac28 arg: try fixing test-args-parser randomly fails (#24826)
* arg: try fixing test-args-parser randomly fails

* return ref

* try triggering the workflow

* exception wrapper

* wip

* test

* test 2

* arg: guard win32 utf8 argv override

make_utf8_argv rebuilds argv from GetCommandLineW to fix utf8 handling of
non ascii arguments on windows. the override runs unconditionally inside
common_params_parse, so it also clobbers a programmatic argv passed by a
caller. test-arg-parser builds a synthetic argv but then sees the real
process command line instead, the model argument is never parsed, and the
assert that expects success aborts via fastfail (0xC0000409). this shows up
as a random failure in the openvino windows workflow.

only override argv when its length matches the caller argc, so the utf8
repair still applies to real binaries while a programmatic argv stays intact.

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-06-20 19:45:27 +02:00
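The guard described in the commit message above reduces to a length comparison. The sketch below is a portable illustration, not the code in `common_params_parse`: `select_argv` is a hypothetical helper, and `cmdline_args` stands in for the arguments that `make_utf8_argv` would recover from `GetCommandLineW` on Windows.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Pick which argument vector to parse.
// caller_args:  the argc/argv handed to the parser by its caller.
// cmdline_args: arguments rebuilt from the process command line
//               (a stand-in for the Windows-only UTF-8 repair).
const std::vector<std::string> & select_argv(
        const std::vector<std::string> & caller_args,
        const std::vector<std::string> & cmdline_args) {
    // Only trust the process command line when its length matches the
    // caller's argc; otherwise the caller passed a programmatic argv
    // that must stay intact.
    if (cmdline_args.size() == caller_args.size()) {
        return cmdline_args;
    }
    return caller_args;
}
```

Note that this is a heuristic: a synthetic argv that happens to have the same length as the real command line would still be replaced.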
Muhammad Salem 8452824611 release: add missing link for win opencl adreno arm64 (#24809) 2026-06-20 23:08:59 +08:00
Matti4 e27f308597 server: avoid forwarding auth headers in CORS proxy (#24373)
* server: avoid forwarding auth headers in CORS proxy

* format

* fix test

* fix e2e test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-06-20 15:34:47 +02:00
Aldehir Rojas 67e9fd3b74 docker : prebuild web UI for s390x build [no release] (#24829) 2026-06-20 05:54:42 -05:00
davidrhodus 796f41bedc model : glm-dsa load DSA indexer tensors as optional (#24770)
GLM-5.2 ships the DSA "lightning indexer" on only a subset of layers (the
"full" layers; others omit it), but the GLM_DSA loader created the five
indexer tensors on every layer as required, so loading any GLM-5.2 GGUF
failed with e.g. `missing tensor 'blk.3.indexer.k_norm.weight'`.

GLM_DSA's graph is llama_model_deepseek2::graph (plain MLA) and does not use
the indexer tensors (indexer runtime not yet implemented), so they are
loaded-but-unused. Marking them TENSOR_NOT_REQUIRED lets layers without an
indexer load as nullptr and the model runs as full MLA attention.

DeepSeek-V3.2 (uniform indexer on all layers) is unaffected.
2026-06-20 13:48:24 +03:00
Adrien Gallouët 37a77fb057 ggml : optimize AMX (#24806)
Flatten the partition over n_batch * M so every thread participates in
the quantization

    | CPU                             | Model                         | Test   |   t/s OLD |   t/s NEW |   Speedup |
    |:--------------------------------|:------------------------------|:-------|----------:|----------:|----------:|
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw  | pp512  |    730.71 |    779.86 |      1.07 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw  | tg128  |     87.88 |     86.79 |      0.99 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | pp512  |    725.09 |   1023.31 |      1.41 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | tg128  |     83.64 |     83.62 |      1.00 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0              | pp512  |    820.51 |    924.05 |      1.13 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0              | tg128  |     90.59 |     92.46 |      1.02 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1              | pp512  |    776.88 |    872.79 |      1.12 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1              | tg128  |     89.39 |     90.94 |      1.02 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M            | pp512  |    719.28 |   1009.27 |      1.40 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M            | tg128  |     80.62 |     80.86 |      1.00 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S            | pp512  |    732.29 |   1077.29 |      1.47 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S            | tg128  |     86.42 |     83.53 |      0.97 |

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-20 13:43:06 +03:00
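One way to read "flatten the partition over n_batch * M" is sketched below. This is an assumption about the general technique, not the AMX code itself: `flat_range` is a hypothetical helper, and the even split by integer division is one common choice among several.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Half-open range [first, second) of flattened work items owned by thread
// `ith` out of `nth`. Items are indexed over total = n_batch * M, so one
// thread's range may cross a batch boundary.
std::pair<int64_t, int64_t> flat_range(int64_t n_batch, int64_t M, int ith, int nth) {
    assert(nth > 0 && ith >= 0 && ith < nth);
    const int64_t total = n_batch * M;
    return { total * ith / nth, total * (ith + 1) / nth };
}

// Recover (batch, row) from a flattened index.
std::pair<int64_t, int64_t> unflatten(int64_t idx, int64_t M) {
    return { idx / M, idx % M };
}
```

For example, with `n_batch = 2`, `M = 3` and 4 threads, the flattened split gives every thread at least one of the 6 items, whereas splitting only the 3 rows of a single batch across 4 threads would leave one thread idle.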
Sigbjørn Skjæret f4043fec01 convert : more consistent handling of rope_parameters (#24833) 2026-06-20 13:42:36 +03:00
130 changed files with 5417 additions and 2743 deletions
-16

@@ -4,20 +4,6 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
### Build Llama.cpp stage
FROM docker.io/gcc:${GCC_VERSION} AS build
@@ -34,8 +20,6 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN --mount=type=cache,target=/root/.ccache \
--mount=type=cache,target=/app/build \
cmake -S . -B build -G Ninja \
-1
@@ -11,7 +11,6 @@
build*/
tools/ui/node_modules/
tools/ui/dist/
models/*
+16 -2
@@ -58,6 +58,13 @@ jobs:
git tag ${{ steps.srctag.outputs.name }} || exit 0
git push origin ${{ steps.srctag.outputs.name }} || exit 0
build_ui:
name: Build UI
needs: create_tag
uses: ./.github/workflows/ui-build.yml
with:
hf_ui_version: ${{ needs.create_tag.outputs.source_tag }}
prepare_matrices:
name: Prepare Docker matrices
runs-on: ubuntu-24.04
@@ -79,7 +86,7 @@ jobs:
[
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
@@ -135,7 +142,7 @@ jobs:
push_to_registry:
name: Push Docker image to Docker Registry
needs: [prepare_matrices, create_tag]
needs: [prepare_matrices, create_tag, build_ui]
runs-on: ${{ matrix.config.runs_on }}
strategy:
@@ -150,6 +157,13 @@ jobs:
fetch-depth: 0
ref: ${{ needs.create_tag.outputs.source_tag }}
- name: Download prebuilt UI
if: ${{ matrix.config.prebuilt_ui == true }}
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
name: ui-build
path: tools/ui/dist
- name: Set up QEMU
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
+1
@@ -1627,6 +1627,7 @@ jobs:
**Windows:**
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
- [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip)
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip)
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
+1 -1
@@ -10,7 +10,7 @@
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-webgpu : reeselevine, yomaytk
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
+26 -11
@@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model
const common_download_opts & opts) {
handle_model_result result;
// TODO @ngxson : refactor this into a new common_model_download_context
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
@@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -407,6 +409,11 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.preset_only = handle_params.preset_only;
if (handle_params.callback) {
opts.callback = handle_params.callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them
@@ -584,19 +591,20 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
const bool skip_model_download =
// server will call common_params_handle_models() later, so we skip it here
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
// export_graph_ops loads only metadata
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
common_params_handle_models(params, ctx_arg.ex, {});
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
if (!can_skip_model && params.model.path.empty()) {
throw std::invalid_argument("error: --model is required\n");
}
}
@@ -924,8 +932,8 @@ static utf8_argv make_utf8_argv() {
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
#ifdef _WIN32
auto utf8 = make_utf8_argv();
if (!utf8.ptrs.empty()) {
argc = static_cast<int>(utf8.buf.size());
// repair argv only when it matches the process command line
if (static_cast<int>(utf8.buf.size()) == argc) {
argv = utf8.ptrs.data();
}
#endif
@@ -1110,6 +1118,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.completion = true;
}
));
add_opt(common_arg(
{"--server-base"}, "URL",
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
[](common_params & params, const std::string & value) {
params.server_base = value;
}
).set_examples({LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--verbose-prompt"},
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
@@ -2897,7 +2912,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.server_tools = parse_csv_row(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
add_opt(common_arg(
add_opt(common_arg(
{"-ag", "--agent"},
{"-no-ag", "--no-agent"},
"whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)",
+10 -1
@@ -1,6 +1,7 @@
#pragma once
#include "common.h"
#include "download.h"
#include <set>
#include <map>
@@ -129,11 +130,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_params_handle_models_params {
common_download_callback * callback = nullptr;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
};
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
bool common_params_handle_models(common_params & params, llama_example curr_ex);
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
const common_params_handle_models_params & handle_params);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+5 -4
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
arguments.name_suffix) +
arguments.value_prefix +
(schema_info.resolves_to_string(param_schema) ?
p.tool_arg_string_value(until_suffix) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
p.ac(p.tool_arg_string_value(until_suffix) +
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
(p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.tool_arg_close(p.literal(arguments.value_suffix)))));
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
if (is_required) {
+103 -53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
common_chat_role common_chat_role_from_string(const std::string & role) {
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
return COMMON_CHAT_ROLE_UNKNOWN;
}
const char * common_chat_role_to_string(common_chat_role role) {
switch (role) {
case COMMON_CHAT_ROLE_SYSTEM: return "system";
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
case COMMON_CHAT_ROLE_USER: return "user";
case COMMON_CHAT_ROLE_TOOL: return "tool";
case COMMON_CHAT_ROLE_UNKNOWN: return "";
}
return "";
}
json common_chat_msg_delimiters::to_json() const {
json result = json::array();
for (const auto & d : delimiters) {
result.push_back({
{ "role", common_chat_role_to_string(d.role) },
{ "delimiter", d.delimiter },
});
}
return result;
}
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
common_chat_msg_delimiters result;
if (!delimiters.is_array()) {
return result;
}
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;
all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
result.delimiters.reserve(delimiters.size());
for (const auto & d : delimiters) {
if (!d.is_object()) {
continue;
}
auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});
common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
result.delimiters.push_back({
common_chat_role_from_string(d.value("role", std::string())),
d.value("delimiter", std::string()),
});
}
std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
return result;
}
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
for (auto & d : delimiters) {
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
}
}
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
std::vector<std::pair<common_chat_role, size_t>> matches;
auto skip = skips.begin();
for (size_t i = 0; i < tokens.size();) {
if (skip != skips.end() && i == skip->first) {
i += skip->second;
++skip;
continue;
}
});
for (const auto & d : delimiters) {
if (i + d.tokens.size() > tokens.size()) {
continue;
}
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
matches.emplace_back(d.role, i);
break;
}
}
i++;
}
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
common_chat_msg_spans spans;
for (size_t i = 0; i + 1 < matches.size(); i++) {
const auto & curr = matches[i];
const auto & next = matches[i + 1];
spans.add(curr.first, curr.second, next.second - curr.second);
}
return spans;
}
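Because the hunk above interleaves removed and added lines, the token-based split is easier to follow in isolation. The sketch below restates it as standalone C++ under stated assumptions: tokens are plain `int`s, and `role`, `delimiter`, `span`, and `split_by_delimiters` are simplified stand-ins for `common_chat_role`, `common_chat_msg_delimiter`, `common_chat_msg_span`, and `common_chat_msg_delimiters::split`. The guard against empty delimiters is an addition of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

enum class role { unknown, system, assistant, user, tool };

struct delimiter { role r; std::vector<int> tokens; };
struct span      { role r; std::size_t pos; std::size_t len; };

// Scan tokens left to right and record where each delimiter starts.
// `skips` maps a start index to the length of a region to jump over
// without matching (for example, non-text tokens).
std::vector<span> split_by_delimiters(
        const std::vector<int> & tokens,
        const std::vector<delimiter> & delims,
        const std::map<std::size_t, std::size_t> & skips = {}) {
    std::vector<std::pair<role, std::size_t>> matches;
    auto skip = skips.begin();
    for (std::size_t i = 0; i < tokens.size();) {
        if (skip != skips.end() && i == skip->first) {
            i += skip->second;
            ++skip;
            continue;
        }
        for (const auto & d : delims) {
            // sketch-only guard: an empty delimiter would match everywhere
            if (d.tokens.empty() || i + d.tokens.size() > tokens.size()) {
                continue;
            }
            if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
                matches.emplace_back(d.r, i); // first delimiter in list order wins
                break;
            }
        }
        ++i;
    }
    // sentinel so the last message extends to the end of the token stream
    matches.emplace_back(role::unknown, tokens.size());

    std::vector<span> spans;
    for (std::size_t i = 0; i + 1 < matches.size(); ++i) {
        spans.push_back({ matches[i].first, matches[i].second,
                          matches[i + 1].second - matches[i].second });
    }
    return spans;
}
```

Each span runs from the start of one delimiter to the start of the next, so tokens before the first delimiter belong to no span. Delimiters are tried in list order, which is why a longer delimiter that shares a prefix with a shorter one needs to be listed first.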
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
};
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
};
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
RESULT_START, RESULT_END,
};
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
// Declare per-role message delimiters. Tool results are rendered with the
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "assistant", GEN_PREFIX },
{ "user", TURN_START + USER },
{ "tool", TURN_START + SYSTEM + RESULT_START },
{ "system", TURN_START + SYSTEM },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.push_back({ "assistant", autoparser.assistant_start });
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
}
if (!autoparser.user_start.empty()) {
delimiters.push_back({ "user", autoparser.user_start });
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
}
if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}
auto_params.message_delimiters = std::move(delimiters);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
+65 -6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
}
};
enum common_chat_role {
COMMON_CHAT_ROLE_UNKNOWN,
COMMON_CHAT_ROLE_SYSTEM,
COMMON_CHAT_ROLE_ASSISTANT,
COMMON_CHAT_ROLE_USER,
COMMON_CHAT_ROLE_TOOL
};
common_chat_role common_chat_role_from_string(const std::string & role);
const char * common_chat_role_to_string(common_chat_role role);
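The diff only declares the two conversion functions above; their definitions are not shown here. As a rough sketch of what such a round trip could look like, assuming the role strings are the ones the old code used ("system", "assistant", "user", "tool") and assuming unknown inputs map to an "unknown" fallback (both the fallback string and every `_sketch` name below are hypothetical):

```cpp
#include <string>

// Hypothetical mirror of the enum introduced in the diff.
enum chat_role_sketch { ROLE_UNKNOWN, ROLE_SYSTEM, ROLE_ASSISTANT, ROLE_USER, ROLE_TOOL };

// Map a role name to the enum; anything unrecognized becomes ROLE_UNKNOWN.
static chat_role_sketch role_from_string_sketch(const std::string & role) {
    if (role == "system")    { return ROLE_SYSTEM; }
    if (role == "assistant") { return ROLE_ASSISTANT; }
    if (role == "user")      { return ROLE_USER; }
    if (role == "tool")      { return ROLE_TOOL; }
    return ROLE_UNKNOWN;
}

// Map the enum back to a name. The "unknown" fallback is an assumption of this sketch.
static const char * role_to_string_sketch(chat_role_sketch role) {
    switch (role) {
        case ROLE_SYSTEM:    return "system";
        case ROLE_ASSISTANT: return "assistant";
        case ROLE_USER:      return "user";
        case ROLE_TOOL:      return "tool";
        case ROLE_UNKNOWN:   break;
    }
    return "unknown";
}
```

Replacing the `std::string role` fields with an enum makes comparisons such as `it->role == COMMON_CHAT_ROLE_USER` cheap and rules out typos in role names at compile time.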
struct common_chat_msg_span {
std::string role;
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::size_t pos = 0;
std::size_t len = 0;
bool valid() const {
return role != COMMON_CHAT_ROLE_UNKNOWN;
}
};
struct common_chat_msg_spans {
std::vector<common_chat_msg_span> spans;
void add(common_chat_role role, size_t pos, size_t len) {
spans.push_back({ role, pos, len });
}
bool is_user_start(int32_t pos) const {
for (auto it = spans.begin(); it != spans.end(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
return true;
}
}
return false;
}
int32_t last_user_message_pos() const {
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER) {
return (int32_t) it->pos;
}
}
return -1;
}
};
struct common_chat_msg_delimiter {
std::string role;
std::string delimiter;
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string delimiter;
llama_tokens tokens = {};
};
struct common_chat_msg_delimiters {
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters() = default;
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
void add(common_chat_role role, const std::string & delimiter) {
delimiters.push_back({ role, delimiter });
}
void tokenize(const llama_vocab * vocab);
// Split tokens into message spans. `skips` maps a start index to the length of a region to skip without matching.
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
nlohmann::ordered_json to_json() const;
};
struct common_chat_tool {
@@ -219,7 +279,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
std::vector<common_chat_msg_span> message_spans;
common_chat_msg_delimiters message_delimiters;
};
// per-message parsing syntax
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
+4 -1
@@ -609,7 +609,7 @@ struct common_params {
bool cache_prompt = true; // whether to enable prompt caching
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 = disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
@@ -631,6 +631,9 @@ struct common_params {
std::map<std::string, std::string> default_template_kwargs;
// CLI params
std::string server_base; // if set, connect to this server instead of starting a new one
// UI configs
bool ui = true;
bool ui_mcp_proxy = false;
+3 -1
@@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model &
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool preset_only = opts.preset_only;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
@@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
} else if (!preset_only) {
// only add other files if we're NOT in preset-only mode (normal run, non-router)
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
+1
@@ -55,6 +55,7 @@ struct common_download_opts {
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
common_download_callback * callback = nullptr;
};
+70
@@ -2,6 +2,16 @@
#include <cpp-httplib/httplib.h>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif
struct common_http_url {
std::string scheme;
std::string user;
@@ -97,3 +107,63 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
}
static int common_http_get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
+89 -46
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
return mk_val<value_undefined>();
}
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
const size_t expected_count = this_args.size();
const size_t input_count = args.count();
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this_args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
ctx.set_val(param_name, param_value);
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
ctx.set_val(param_name, param_value);
} else {
throw std::runtime_error("Invalid parameter type in '" + name + "'");
}
} else {
auto & default_arg = this_args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
} else {
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
}
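The binding rule that `bind_parameters` factors out (use the supplied argument when one exists at that index, otherwise fall back to the declared default, otherwise fail) can be sketched in isolation. This is a simplified model, not the template engine's code: values are plain `int`s, keyword arguments are left out, and the names `param` and `bind_params` are hypothetical.

```cpp
#include <cstddef>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical parameter declaration: a name plus an optional default.
struct param {
    std::string        name;
    std::optional<int> default_value; // empty means the parameter is required
};

// Bind positional arguments first, then defaults; throw if a required parameter is missing.
static std::map<std::string, int> bind_params(const std::vector<param> & params,
                                               const std::vector<int> & positional) {
    std::map<std::string, int> scope;
    for (std::size_t i = 0; i < params.size(); ++i) {
        if (i < positional.size()) {
            scope[params[i].name] = positional[i];            // supplied argument wins
        } else if (params[i].default_value) {
            scope[params[i].name] = *params[i].default_value; // fall back to the default
        } else {
            throw std::runtime_error("Not enough arguments provided for '" + params[i].name + "'");
        }
    }
    return scope;
}
```

Pulling this logic into one function lets both the macro path and the new `{% call %}` path share it instead of duplicating the loop.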
value macro_statement::execute_impl(context & ctx) {
if (!is_stmt<identifier>(this->name)) {
throw std::runtime_error("Macro name must be an identifier");
}
std::string name = cast_stmt<identifier>(this->name)->val;
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
size_t expected_count = this->args.size();
size_t input_count = args.count();
const func_handler func = [this, name](const func_args & args) -> value {
context macro_ctx(args.ctx); // new scope for macro execution
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
context macro_ctx(ctx); // new scope for macro execution
// bind parameters
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this->args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
macro_ctx.set_val(param_name, param_value);
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
macro_ctx.set_val(param_name, param_value);
} else {
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
}
} else {
auto & default_arg = this->args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
} else {
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
bind_parameters(name, this->args, args, macro_ctx);
// execute macro body
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
return mk_val<value_undefined>();
}
value call_statement::execute_impl(context & ctx) {
auto call_expr = cast_stmt<call_expression>(this->call);
if (!call_expr) {
throw std::runtime_error("Call statement requires a valid call expression");
}
value callee_val = call_expr->callee->execute(ctx);
if (!is_val<value_func>(callee_val)) {
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
}
auto * callee_func = cast_val<value_func>(callee_val);
context caller_ctx(ctx); // new scope for caller execution
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
context block_ctx(caller_ctx); // new scope for block execution
bind_parameters("caller", this->caller_args, args, block_ctx);
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
auto res = exec_statements(this->body, block_ctx);
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
return res;
};
context call_ctx(ctx);
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
func_args args(call_ctx);
for (const auto & arg_expr : call_expr->args) {
auto arg_val = arg_expr->execute(ctx);
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
args.push_back(arg_val);
}
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
return callee_func->invoke(args);
}
value member_expression::execute_impl(context & ctx) {
value object = this->object->execute(ctx);
+1
@@ -552,6 +552,7 @@ struct call_statement : public statement {
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
};
struct ternary_expression : public expression {
+23 -23
@@ -233,27 +233,27 @@ struct BuiltinRule {
};
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"boolean", {"(\"true\" | \"false\") space", {}}},
{"boolean", {"(\"true\" | \"false\")", {}}},
{"decimal-part", {"[0-9]{1,16}", {}}},
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}},
{"integer", {"(\"-\"? integral-part)", {"integral-part"}}},
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}},
{"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}},
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
{"null", {"\"null\" space", {}}},
{"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}},
{"null", {"\"null\"", {}}},
};
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
{"date-time", {"date \"T\" time", {"date", "time"}}},
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
{"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}},
{"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}},
{"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}}
};
static bool is_reserved_name(const std::string & name) {
@@ -551,16 +551,16 @@ private:
}
return join_seq();
};
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"");
}
/*
Returns a rule that matches a JSON string that is none of the provided strings
not_strings({"a"})
-> ["] ( [a] char+ | [^"a] char* )? ["] space
-> ["] ( [a] char+ | [^"a] char* )? ["]
not_strings({"and", "also"})
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
*/
std::string _not_strings(const std::vector<std::string> & strings) {
@@ -619,7 +619,7 @@ private:
if (!trie.is_end_of_string) {
out << "?";
}
out << " [\"] space";
out << " [\"]";
return out.str();
}
@@ -725,7 +725,7 @@ private:
rule += " )?";
}
rule += " \"}\" space";
rule += " space \"}\"";
return rule;
}
@@ -858,14 +858,14 @@ public:
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
}
if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
}
if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")");
}
if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
@@ -933,7 +933,7 @@ public:
}
}
if (!enum_intersection.empty()) {
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")");
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
@@ -948,7 +948,7 @@ public:
}
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
}
rule += " \"]\" space";
rule += " space \"]\"";
return _add_rule(rule_name, rule);
}
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
@@ -956,7 +956,7 @@ public:
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\"");
}
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name);
@@ -972,7 +972,7 @@ public:
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\"");
}
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
int64_t min_value = std::numeric_limits<int64_t>::min();
@@ -990,7 +990,7 @@ public:
std::stringstream out;
out << "(";
build_min_max_int(min_value, max_value, out);
out << ") space";
out << ")";
return _add_rule(rule_name, out.str());
}
if (schema.empty() || schema_type == "object") {
+202 -89
@@ -6,13 +6,14 @@
#include "unicode.h"
#include <algorithm>
#include <deque>
#include <initializer_list>
#include <map>
#include <memory>
#include <nlohmann/json.hpp>
#include <regex>
#include <set>
#include <stdexcept>
#include <unordered_set>
// Trick to catch missing branches
template <typename T>
@@ -88,40 +89,7 @@ struct trie {
return match_result{match_result::NO_MATCH};
}
struct prefix_and_next {
std::vector<uint32_t> prefix;
std::vector<uint32_t> next_chars;
};
std::vector<prefix_and_next> collect_prefix_and_next() {
std::vector<uint32_t> prefix;
std::vector<prefix_and_next> result;
collect_prefix_and_next(0, prefix, result);
return result;
}
private:
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
if (!nodes[index].is_word) {
if (!nodes[index].children.empty()) {
std::vector<uint32_t> chars;
chars.reserve(nodes[index].children.size());
for (const auto & p : nodes[index].children) {
chars.push_back(p.first);
}
out.emplace_back(prefix_and_next{prefix, chars});
}
}
for (const auto & p : nodes[index].children) {
uint32_t ch = p.first;
auto child = p.second;
prefix.push_back(ch);
collect_prefix_and_next(child, prefix, out);
prefix.pop_back();
}
}
size_t create_node() {
size_t index = nodes.size();
nodes.emplace_back();
@@ -153,6 +121,65 @@ struct trie {
}
};
// Aho-Corasick automaton
struct aho_corasick {
trie t;
std::vector<size_t> fail; // failure links
std::vector<size_t> order; // states in BFS order
std::vector<bool> terminal; // match states (directly or via a suffix link)
std::set<uint32_t> alphabet; // every character with a transition
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
const auto & nodes = t.nodes;
const size_t n = nodes.size();
fail.assign(n, 0);
order.reserve(n);
std::deque<size_t> queue{ 0 };
while (!queue.empty()) {
size_t u = queue.front();
queue.pop_front();
order.push_back(u);
for (const auto & [ch, v] : nodes[u].children) {
if (u != 0) {
size_t f = fail[u];
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
f = fail[f];
}
auto it = nodes[f].children.find(ch);
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
}
queue.push_back(v);
}
}
terminal.assign(n, false);
for (size_t u : order) {
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
}
for (const auto & node : nodes) {
for (const auto & [ch, v] : node.children) {
alphabet.insert(ch);
}
}
}
size_t num_states() const { return t.nodes.size(); }
bool is_terminal(size_t s) const { return terminal[s]; }
// follow failure links until a transition on `ch` exists.
size_t next(size_t state, uint32_t ch) const {
const auto & nodes = t.nodes;
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
state = fail[state];
}
auto it = nodes[state].children.find(ch);
return it != nodes[state].children.end() ? it->second : 0;
}
};
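To make the role of the failure links and the `terminal` flags concrete, here is a self-contained sketch of the same idea over plain `char`s. It does not reuse the `trie` type from this file; `ac_sketch` and `find_first_match_end` are hypothetical names, and the scan function is only one possible use of the automaton (finding where the earliest-ending delimiter occurrence stops).

```cpp
#include <cstddef>
#include <deque>
#include <map>
#include <string>
#include <vector>

struct ac_sketch {
    struct node {
        std::map<char, std::size_t> children;
        bool is_word = false;
    };
    std::vector<node>        nodes;
    std::vector<std::size_t> fail;     // longest proper suffix that is also a trie path
    std::vector<bool>        terminal; // a word ends here, directly or via a suffix link

    explicit ac_sketch(const std::vector<std::string> & words) : nodes(1) {
        for (const auto & w : words) { // build the trie
            std::size_t cur = 0;
            for (char ch : w) {
                auto it = nodes[cur].children.find(ch);
                if (it == nodes[cur].children.end()) {
                    std::size_t id = nodes.size();
                    nodes.emplace_back();
                    nodes[cur].children[ch] = id;
                    cur = id;
                } else {
                    cur = it->second;
                }
            }
            nodes[cur].is_word = true;
        }
        fail.assign(nodes.size(), 0);
        terminal.assign(nodes.size(), false);
        std::deque<std::size_t> queue{ 0 };
        while (!queue.empty()) { // BFS: fail[u] is always shallower, so it is visited before u
            std::size_t u = queue.front();
            queue.pop_front();
            terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
            for (const auto & kv : nodes[u].children) {
                if (u != 0) {
                    fail[kv.second] = step(fail[u], kv.first);
                }
                queue.push_back(kv.second);
            }
        }
    }

    // Follow failure links until a transition on `ch` exists, else return to the root.
    std::size_t step(std::size_t state, char ch) const {
        while (state != 0 && nodes[state].children.count(ch) == 0) {
            state = fail[state];
        }
        auto it = nodes[state].children.find(ch);
        return it != nodes[state].children.end() ? it->second : 0;
    }
};

// Index one past the end of the earliest-ending occurrence of any word, or npos if none occurs.
static std::size_t find_first_match_end(const ac_sketch & ac, const std::string & text) {
    std::size_t state = 0;
    for (std::size_t i = 0; i < text.size(); ++i) {
        state = ac.step(state, text[i]);
        if (ac.terminal[state]) {
            return i + 1;
        }
    }
    return std::string::npos;
}
```

The suffix-propagated `terminal` flag matters when one word is a substring of another: with the words "abc" and "b", the input "ab" reaches the state for "ab", whose failure link points at "b", so the match is reported there.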
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
if (pos + hex_count > str.length()) {
return {0, 0};
@@ -894,6 +921,10 @@ struct parser_executor {
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
};
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_not_parser> ||
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser>) {
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser>) {
p.child = resolve_ref(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
p.child = resolve_ref(p.child);
@@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() {
}
std::string common_peg_arena::dump(common_peg_parser_id id) const {
std::unordered_set<common_peg_parser_id> visited;
std::set<common_peg_parser_id> visited;
return dump_impl(id, visited);
}
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
std::unordered_set<common_peg_parser_id> & visited) const {
std::set<common_peg_parser_id> & visited) const {
// Check for cycles
if (visited.count(id)) {
return "[cycle]";
@@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "Atomic(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
return "Any";
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() {
common_peg_parser common_peg_parser_builder::json_array() {
return rule("json-array", [this]() {
auto ws = space();
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))});
return sequence({
literal("["),
ws,
@@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
});
}
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
if (delimiters.empty()) {
throw std::runtime_error("ac parser requires at least one delimiter");
}
return add(common_peg_ac_parser{p, delimiters});
}
static std::string gbnf_escape_char_class(uint32_t c) {
if (c == '-' || c == ']' || c == '[' || c == '\\') {
return "\\" + std::string(1, (char) c);
@@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) {
return std::string(buf);
}
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
trie matcher(strings);
auto pieces = matcher.collect_prefix_and_next();
std::string pattern;
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
for (size_t i = 0; i < pieces.size(); ++i) {
if (i > 0) {
pattern += " | ";
}
const auto & pre = pieces[i].prefix;
const auto & chars = pieces[i].next_chars;
std::string cls;
cls.reserve(chars.size());
for (uint32_t ch : chars) {
cls += gbnf_escape_char_class(ch);
}
if (!pre.empty()) {
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
pattern += pre_literal + " [^" + cls + "]";
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
// char, so the repetition alone cannot match a value that *ends* on a proper
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
// values, so without this the grammar would reject input the parser accepts.
// Allow the value to terminate on any proper prefix as an optional tail.
// This makes the grammar a slight superset of the runtime language (a value
// may end on the longest prefix, which greedy first-match would not itself
// produce); harmless for constrained generation, which only needs to admit
// every runtime-valid string.
if (!trailing.empty()) {
trailing += " | ";
}
trailing += pre_literal;
} else {
pattern += "[^" + cls + "]";
}
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
s += gbnf_escape_char_class(ch);
}
std::string result = "(" + pattern + ")*";
if (!trailing.empty()) {
result += " (" + trailing + ")?";
}
return result;
return s + "]";
}
static std::unordered_set<std::string> collect_reachable_rules(
static std::string gbnf_ac_grammar(
const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings,
const std::function<std::string(const std::vector<uint32_t> &,
const std::map<size_t, std::vector<uint32_t>> &,
const std::vector<uint32_t> &,
const std::function<std::string(size_t)> &)> & build_rule) {
aho_corasick ac(strings);
auto state_name = [&](size_t s) -> std::string {
if (s == 0) {
return prefix;
}
std::string num = std::to_string(s);
num = num.size() == 1 ? ("0" + num) : num;
return prefix + "-" + num;
};
for (size_t q = 0; q < ac.num_states(); q++) {
if (ac.is_terminal(q)) {
continue; // match states
}
std::map<size_t, std::vector<uint32_t>> buckets;
std::vector<uint32_t> completing; // chars that complete a delimiter
std::vector<uint32_t> specific; // chars with an explicit transition
for (uint32_t c : ac.alphabet) {
size_t d = ac.next(q, c);
if (ac.is_terminal(d)) {
completing.push_back(c);
specific.push_back(c);
} else if (d != 0) {
buckets[d].push_back(c); // specific non-root destination
specific.push_back(c);
}
}
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
}
// An empty delimiter makes the start state terminal. Emit an entry rule
// that matches the empty string so the returned reference stays valid.
if (ac.is_terminal(0)) {
builder.add_rule(prefix, "|");
}
return state_name(0);
}
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & /*completing*/,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
// every state is accepting and completing chars get no
// alternative, so a forbidden string can never be matched
std::string rhs = "|";
for (const auto & [d, chars] : buckets) {
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
return rhs;
});
}
// GBNF grammar matching everything up to and including the first occurrence of
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
// the start state rule name.
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & completing,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
std::vector<std::string> alts;
if (!completing.empty()) {
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
}
for (const auto & [d, chars] : buckets) {
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
}
// every other character keeps scanning from the start state
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
return string_join(alts, " | ");
});
}
static std::set<std::string> collect_reachable_rules(
const common_peg_arena & arena,
const common_peg_parser_id & rule
) {
std::unordered_set<std::string> reachable;
std::unordered_set<std::string> visited;
std::set<std::string> reachable;
std::set<std::string> visited;
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
const auto & parser = arena.get(id);
@@ -1588,6 +1686,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser> ||
std::is_same_v<T, common_peg_schema_parser>) {
visit(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
if (p.delimiters.empty()) {
return ".*";
}
return gbnf_excluding_pattern(p.delimiters);
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
if (schema_delegates(p)) {
return to_gbnf(p.child);
@@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return to_gbnf(p.child);
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return p.grammar;
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
} else {
static_assert(is_always_false_v<T>);
}
@@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
};
// Collect reachable rules
std::unordered_set<std::string> reachable_rules;
std::set<std::string> reachable_rules;
if (lazy) {
// Collect rules reachable from trigger rules
@@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
};
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
}
}, variant);
}
@@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
};
}
if (type == "ac") {
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
}
return common_peg_ac_parser{
j["child"].get<common_peg_parser_id>(),
j["delimiters"].get<std::vector<std::string>>(),
};
}
throw std::runtime_error("Unknown parser type: " + type);
}
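Reviewer note: the comment on `gbnf_including_grammar` says the emitted grammar matches everything up to and including the first occurrence of any delimiter. The sketch below illustrates that matching behaviour with a small Aho-Corasick automaton in Python. It is not the PR's C++ code and emits no GBNF; the names `build_automaton` and `match_through_first` are hypothetical and exist only for this illustration.

```python
from collections import deque


def build_automaton(patterns):
    """Build goto/fail/accepting tables for a small Aho-Corasick automaton."""
    goto = [{}]          # goto[state][char] -> next state
    fail = [0]           # failure link per state
    accepting = [False]  # True if some pattern ends at (a suffix of) this state
    for pat in patterns:
        state = 0
        for ch in pat:
            if ch not in goto[state]:
                goto.append({})
                fail.append(0)
                accepting.append(False)
                goto[state][ch] = len(goto) - 1
            state = goto[state][ch]
        accepting[state] = True

    # Breadth-first pass: depth-1 states keep fail = 0, deeper states
    # follow their parent's failure chain.
    queue = deque(goto[0].values())
    while queue:
        state = queue.popleft()
        for ch, nxt in goto[state].items():
            f = fail[state]
            while f and ch not in goto[f]:
                f = fail[f]
            fail[nxt] = goto[f].get(ch, 0)
            accepting[nxt] = accepting[nxt] or accepting[fail[nxt]]
            queue.append(nxt)
    return goto, fail, accepting


def match_through_first(text, patterns):
    """Return the prefix of text ending at the first delimiter match, or None."""
    goto, fail, accepting = build_automaton(patterns)
    state = 0
    for i, ch in enumerate(text):
        while state and ch not in goto[state]:
            state = fail[state]
        state = goto[state].get(ch, 0)
        if accepting[state]:
            return text[: i + 1]
    return None
```

The "excluding" variant in the same file would, by its comment, accept text in which no such match ever completes; this sketch covers only the "including" case.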
+16 -3
@@ -3,8 +3,8 @@
#include <nlohmann/json_fwd.hpp>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <string_view>
#include <functional>
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
std::string grammar;
};
struct common_peg_ac_parser {
common_peg_parser_id child;
std::vector<std::string> delimiters;
};
// Variant holding all parser types
using common_peg_parser_variant = std::variant<
common_peg_epsilon_parser,
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
common_peg_ref_parser,
common_peg_atomic_parser,
common_peg_tag_parser,
common_peg_gbnf_parser
common_peg_gbnf_parser,
common_peg_ac_parser
>;
class common_peg_arena {
@@ -335,7 +341,7 @@ class common_peg_arena {
friend class common_peg_parser_builder;
private:
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
common_peg_parser_id add_parser(common_peg_parser_variant parser);
void add_rule(const std::string & name, common_peg_parser_id id);
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
// the child's grammar. Parsing delegates entirely to the child.
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
// automaton of `delimiters`, matching everything up to and including the
// first delimiter. Parsing delegates entirely to the child, which is
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
void set_root(const common_peg_parser & p);
common_peg_arena build();
+102 -35
@@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// One MTP draft driver, three modes (set once in the ctor):
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
// neither (qwen35 / qwen35moe): a single trained MTP head.
int32_t n_mtp_layers = 1;
bool is_mem_shared = false; // gemma4
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
@@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
std::vector<std::vector<float>> verify_h;
std::vector<int32_t> verify_h_rows;
// Per-seq draft length from the last draft() call, used in accept() to
// roll back ctx_dft's recurrent state past the AR draft's redundant
// pre-advancement before process() mirrored the verify batch.
std::vector<uint16_t> last_n_drafted;
std::vector<int> i_last;
std::vector<std::vector<float>> chain_h;
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
@@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
if (chain_heads) {
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
chain_h.assign(n_seq, {});
for (auto & c : chain_h) {
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
}
}
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_last.assign(n_seq, -1);
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
verify_h.assign(n_seq, {});
verify_h_rows.assign(n_seq, 0);
last_n_drafted.assign(n_seq, 0);
}
~common_speculative_impl_draft_mtp() override {
@@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
auto * mem_dft = llama_get_memory(ctx_dft);
bool ok = true;
for (int head = 0; head < n_mtp_layers; ++head) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
}
llama_set_nextn_layer_offset(ctx_dft, head);
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
if (!ok) {
return false;
}
}
@@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
i_last[seq_id] = batch.n_tokens - 1;
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
if (chain_heads) {
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
}
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
// each step decodes under a different head, i.e. a different decoder layer, and
// KV is per layer. process() filled this layer's KV only for positions < n_past
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
// and select head i so it rebuilds its own layer's KV there; decoding just the
// latest token would leave its attention reading cells only another head wrote.
if (chain_heads) {
auto * mem_dft = llama_get_memory(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (drafting[seq_id]) {
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
}
}
llama_set_nextn_layer_offset(ctx_dft, i);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
// rebuild the batch for the next step: the growing-KV paths re-add only the
// new token (the KV already holds the prefix), while chained heads re-add the
// whole prefix at the next head. dropped sequences are simply not re-added.
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
if (is_mem_shared) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
for (int t = 0; t < n_rows; ++t) {
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
}
} else if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
i_last[seq_id] = batch.n_tokens - 1;
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
@@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
}
}
@@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
@@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_mtp) {
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
}
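Reviewer note: in the `chain_heads` branch above, each draft step re-adds the whole prefix (`id_last` plus the tokens drafted so far) at consecutive positions and requests logits only for the last row. The Python sketch below models just that row layout for one sequence. The helper name `rebuild_chain_rows` is hypothetical, and the per-row hidden-state copy from `chain_h` is left out.

```python
def rebuild_chain_rows(id_last, drafted, n_past):
    """Return (token, position, wants_logits) rows for one sequence.

    Mirrors the loop in the diff: row 0 is id_last, row t is drafted[t - 1],
    positions start at n_past, and only the final row asks for logits.
    """
    tokens = [id_last] + list(drafted)
    n_rows = len(tokens)
    return [(tok, n_past + t, t == n_rows - 1) for t, tok in enumerate(tokens)]
```

With `id_last=7`, `drafted=[11, 12]` and `n_past=100`, this yields rows at positions 100, 101 and 102, with logits requested only at 102.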
+2
@@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"GraniteMoeHybridForCausalLM": "granite",
"GraniteMoeSharedForCausalLM": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"Grok1ForCausalLM": "grok",
"GrokForCausalLM": "grok",
"GroveMoeForCausalLM": "grovemoe",
@@ -261,6 +262,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"GlmasrModel": "ultravox",
"Granite4VisionForConditionalGeneration": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"HunYuanVLForConditionalGeneration": "hunyuan",
"Idefics3ForConditionalGeneration": "smolvlm",
"InternVisionModel": "internvl",
+1 -1
@@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel):
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+7 -1
@@ -1119,8 +1119,10 @@ class TextModel(ModelBase):
rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True)
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True)
partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True)
original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True)
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
# Ensure global params are mirrored in rope_parameters
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
if local_rope_theta is not None:
self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
@@ -1128,6 +1130,10 @@ class TextModel(ModelBase):
self.rope_parameters["rope_theta"] = rope_theta
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
self.rope_parameters["rope_type"] = rope_type
if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None:
self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None:
self.rope_parameters["original_max_position_embeddings"] = original_max_position_embeddings
@classmethod
def __init_subclass__(cls):
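Reviewer note: the hunk above mirrors `partial_rotary_factor` and `original_max_position_embeddings` from the top-level hparams into `rope_parameters` when they are not already there. A minimal standalone sketch of that fallback follows, using plain dicts. The function `mirror_rope_params` is hypothetical; the key aliases are the ones listed in the diff, and the surrounding `full_attention` / `sliding_attention` condition is omitted.

```python
def mirror_rope_params(hparams: dict, rope_parameters: dict) -> dict:
    """Copy selected top-level hparams into rope_parameters if absent."""
    out = dict(rope_parameters)  # leave the caller's dict untouched

    def find(keys):
        # First alias present in hparams wins; None if no alias is present.
        for key in keys:
            if key in hparams:
                return hparams[key]
        return None

    mirrored = {
        "partial_rotary_factor": ["partial_rotary_factor", "rope_pct", "rope_percent"],
        "original_max_position_embeddings": ["original_max_position_embeddings"],
    }
    for target, aliases in mirrored.items():
        value = find(aliases)
        if target not in out and value is not None:
            out[target] = value
    return out
```

Values already present in `rope_parameters` take precedence, which is why the per-model converters in the later hunks can read from `self.rope_parameters` alone.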
+1 -1
@@ -148,7 +148,7 @@ class ChatGLMModel(TextModel):
rope_dim = self.hparams["attention_dim"]
else:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_add_bos_token(False)
rope_freq = 10000
if "rope_ratio" in self.hparams:
+1 -1
@@ -161,7 +161,7 @@ class DeciModel(TextModel):
factor = rope_params.get("factor", 8.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+3 -3
@@ -24,7 +24,7 @@ class ExaoneModel(TextModel):
assert (hparams["activation_function"] == "silu")
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = self.rope_parameters.get("partial_rotary_factor")
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
@@ -39,7 +39,7 @@ class ExaoneModel(TextModel):
factor = rope_params.get("factor", 8.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
@@ -104,7 +104,7 @@ class Exaone4Model(TextModel):
factor = rope_params.get("factor", 16.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+1 -1
@@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model):
self.gguf_writer.add_head_count_kv(value_arr)
# handle n_rot differently for global vs swa layers
partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0)
n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors
n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
self.gguf_writer.add_rope_dimension_count(n_rot_full)
+2 -2
@@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel):
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
)
self.gguf_writer.add_rope_dimension_count(
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))
)
# MoE parameters - Use only routed expert count (shared experts handled separately)
@@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model):
super().set_gguf_parameters()
rope_dim = self.hparams["qk_rope_head_dim"]
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0)
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
# NextN/MTP prediction layers
+28
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
has_vision_encoder = False
has_audio_encoder = True
def set_gguf_parameters(self):
assert self.hparams_audio is not None
super().set_gguf_parameters()
# Add feature_layer if present in encoder config
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
self.gguf_writer.add_audio_feature_layers(feature_layers)
logger.info(f"gguf: audio feature_layers = {feature_layers}")
# Validate projector dimension matches concatenated encoder output
hidden_dim = self.hparams_audio["hidden_dim"]
expected_dim = hidden_dim * (len(feature_layers) + 1)
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
if projector_dim != expected_dim:
raise ValueError(
f"Projector encoder_hidden_size ({projector_dim}) does not match "
f"expected concatenated dimension ({expected_dim}). "
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
)
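Reviewer note: the validation above expects the projector input width to equal `hidden_dim * (len(feature_layers) + 1)`, i.e. one encoder-width slice per concatenated feature layer plus one more. The sketch below reproduces that arithmetic outside the converter. The function names and the numbers in the usage note are hypothetical, not taken from any real Granite Speech Plus config.

```python
def expected_projector_dim(hidden_dim: int, feature_layers: list) -> int:
    """Width of the concatenated encoder output fed to the projector."""
    return hidden_dim * (len(feature_layers) + 1)


def check_projector_dim(hidden_dim: int, feature_layers: list, projector_dim: int) -> int:
    """Raise ValueError when the projector width does not match the formula."""
    expected = expected_projector_dim(hidden_dim, feature_layers)
    if projector_dim != expected:
        raise ValueError(
            f"Projector encoder_hidden_size ({projector_dim}) does not match "
            f"expected concatenated dimension ({expected})"
        )
    return expected
```

For example, with a hypothetical `hidden_dim` of 1024 and three entries in `cat_hidden_layers`, the projector would need an `encoder_hidden_size` of 1024 * 4 = 4096.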
@ModelBase.register("Granite4VisionForConditionalGeneration")
class Granite4VisionMmprojModel(MmprojModel):
has_vision_encoder = True
+1 -1
@@ -289,7 +289,7 @@ class LlamaModel(TextModel):
factor = rope_params.get("factor", 8.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+1 -1
@@ -154,7 +154,7 @@ class MimoV2Model(TextModel):
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"])
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
+6 -10
@@ -32,11 +32,9 @@ class MiniCPMModel(TextModel):
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is not None:
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
long_factors = self.rope_parameters.get('long_factor')
short_factors = self.rope_parameters.get('short_factor')
if long_factors or short_factors:
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling.short_factor')
@@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel):
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is not None:
long_factors = self.rope_parameters.get('long_factor')
short_factors = self.rope_parameters.get('short_factor')
if long_factors or short_factors:
rope_dims = self.hparams["qk_rope_head_dim"]
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling.short_factor')
+4 -3
@@ -125,17 +125,18 @@ class NemotronModel(TextModel):
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
# * Partial RoPE
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
rot_pct = self.rope_parameters["partial_rotary_factor"]
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
# * RopeScaling for Nemotron
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
factor = self.hparams.get("factor") or self.rope_parameters.get("factor")
if factor is None:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
self.gguf_writer.add_rope_scaling_factor(factor)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
+9 -11
@@ -18,7 +18,7 @@ class Phi2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PHI2
def set_gguf_parameters(self):
rot_pct = self.find_hparam(["partial_rotary_factor"])
rot_pct = self.rope_parameters["partial_rotary_factor"]
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
@@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel):
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
rms_eps = self.find_hparam(["rms_norm_eps"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
rope_dims = int(rot_pct * n_embd) // n_head
self.gguf_writer.add_context_length(max_pos_embds)
@@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel):
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
rope_dims = int(rot_pct * n_embd) // n_head
# write rope scaling for long context (128k) model
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is None:
long_factors = self.rope_parameters.get('long_factor')
short_factors = self.rope_parameters.get('short_factor')
if not long_factors:
return
scale = max_pos_embds / orig_max_pos_embds
rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower()
rope_scaling_type = self.rope_parameters.get('rope_type', '').lower()
if len(rope_scaling_type) == 0:
raise KeyError('Missing the required key rope_scaling.type')
@@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel):
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling.short_factor')
+1 -1
@@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel):
self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4))
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25)))
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
+1 -1
@@ -28,7 +28,7 @@ class StableLMModel(TextModel):
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
rotary_factor = self.rope_parameters["partial_rotary_factor"]
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
+1 -1
@@ -314,7 +314,7 @@ class Step35Model(TextModel):
factor = float(rope_params.get("factor", 8.0))
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
old_context_len = int(rope_params.get("original_max_position_embeddings", 8192))
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+1 -1
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
```
$ apt update && apt upgrade -y
$ apt install git cmake
$ apt install git cmake libandroid-spawn
```
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
+21 -21
@@ -198,18 +198,18 @@ class BuiltinRule:
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
'boolean' : BuiltinRule('("true" | "false")', []),
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']),
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []),
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
'string' : BuiltinRule(r'"\"" char* "\""', ['char']),
'null' : BuiltinRule('"null"', []),
}
# TODO: support "uri", "email" string formats
@@ -217,9 +217,9 @@ STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
'date-string' : BuiltinRule('"\\"" date "\\""', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\""', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']),
}
DOTALL = '[\\U00000000-\\U0010FFFF]'
@@ -319,7 +319,7 @@ class SchemaConverter:
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
visit(trie)
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
out.append(f' ){"" if trie.is_end_of_string else "?"} ["]')
return ''.join(out)
def _add_rule(self, name, rule):
@@ -549,7 +549,7 @@ class SchemaConverter:
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"")
def _resolve_ref(self, ref):
@@ -580,10 +580,10 @@ class SchemaConverter:
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
elif 'const' in schema:
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
elif 'enum' in schema:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'object') and \
@@ -624,7 +624,7 @@ class SchemaConverter:
enum_intersection &= s
if enum_intersection:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')'
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
@@ -638,12 +638,12 @@ class SchemaConverter:
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
' space "]"')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"')
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
@@ -663,7 +663,7 @@ class SchemaConverter:
min_len = schema.get('minLength', 0)
max_len = schema.get('maxLength')
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""')
elif schema_type in (None, 'integer') and \
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
@@ -680,7 +680,7 @@ class SchemaConverter:
out = ["("]
_generate_min_max_int(min_value, max_value, out)
out.append(") space")
out.append(")")
return self._add_rule(rule_name, ''.join(out))
elif (schema_type == 'object') or (len(schema) == 0):
@@ -765,7 +765,7 @@ class SchemaConverter:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
rule += ' space "}"'
return rule
+5 -6
@@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
parallel_for_ggml(params, n_batch * M, [&](int begin, int end) {
for (int idx = begin; idx < end; ++idx) {
int batch_idx = idx / M;
int m = idx % M;
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
}
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
}
});
});
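The hunk above replaces a parallel loop over n_batch (with a serial inner loop over M rows) by one flat parallel loop over n_batch * M, recovering both indices from the flat one. A minimal sketch of that index mapping follows; split_flat_range is a hypothetical helper for illustration only and is not part of the AMX code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: list the (batch_idx, m) pairs covered by the flat
// index range [begin, end), using the same mapping as the new loop body:
//   batch_idx = idx / M, m = idx % M.
std::vector<std::pair<int, int>> split_flat_range(int begin, int end, int M) {
    assert(M > 0);
    std::vector<std::pair<int, int>> out;
    for (int idx = begin; idx < end; ++idx) {
        out.emplace_back(idx / M, idx % M);
    }
    return out;
}
```

With M = 3, the flat range [2, 5) covers the last row of batch 0 and the first two rows of batch 1, so a single batch with many rows can still be divided among threads.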
+10 -9
@@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
// transposed into VTCM.
//
// VTCM layouts (per thread):
// src1_T : {d_inner_per_thread, d_conv} staged once per launch (small).
// src0_T : {d_inner_tile, ncs} staged per d_inner-tile.
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
//
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
// Each thread iterates ceil(d_inner_per_thread / d_inner_tile) tiles serially.
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
static inline void transpose_src1(const float * src1_data,
uint32_t src1_stride_inner,
uint32_t i1_off,
uint32_t d_inner_per_thread,
uint32_t d_inner_stride,
uint32_t d_conv,
float * src1_T) {
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
for (uint32_t j = 0; j < d_conv; ++j) {
src1_T[j * d_inner_per_thread + i] = src_row[j];
src1_T[j * d_inner_stride + i] = src_row[j];
}
}
}
@@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
}
const uint32_t d_inner_per_thread = ir1 - ir0;
const uint32_t d_inner_stride = scctx->nrows_per_thread;
const uint32_t d_inner_tile = scctx->d_inner_tile;
const float * src0_data = (const float *) src0->data;
@@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
const uint32_t C_TILE = VLEN_FP32;
@@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
for (uint32_t j = 0; j < d_conv; ++j) {
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
}
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
@@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
use_hvx = 1;
}
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
const uint32_t ncs = src0->ne[0];
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
}
// reduction in local memory, assumes #wave=4
+5
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
#endif
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
+155 -53
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
}
template<typename T>
static __dpct_inline__ T op_abs(T x) {
return sycl::fabs(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fabs(x);
} else {
return sycl::fabs(x);
}
}
template<typename T>
static __dpct_inline__ T op_expm1(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::expm1(static_cast<float>(x))
);
} else {
return sycl::expm1(x);
}
}
template<typename T>
static __dpct_inline__ T op_elu(T x) {
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
}
template<typename T>
static __dpct_inline__ T op_tanh(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
return sycl::ext::oneapi::experimental::tanh(x);
#else
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
#endif
} else {
return sycl::tanh(x);
}
}
template<typename T>
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
return static_cast<T>(0.5f) * x *
(static_cast<T>(1.0f) +
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
}
template<typename T>
static __dpct_inline__ T op_exp(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::exp(x);
} else {
return sycl::exp(x);
}
}
template<typename T>
static __dpct_inline__ T op_silu(T x) {
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
return x / (static_cast<T>(1.0f) + op_exp(-x));
}
template<typename T>
static __dpct_inline__ T op_gelu_quick(T x) {
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
static __dpct_inline__ T op_erf(T x) {
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::erf(static_cast<float>(x))
);
} else {
return sycl::erf(x);
}
}
template<typename T>
static __dpct_inline__ T op_gelu_erf(T x) {
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
}
template<typename T>
static __dpct_inline__ T op_tanh(T x) {
return sycl::tanh(x);
static __dpct_inline__ T op_gelu_quick(T x) {
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
}
template<typename T>
static __dpct_inline__ T op_relu(T x) {
return sycl::fmax(x, static_cast<T>(0));
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
} else {
return sycl::fmax(x, static_cast<T>(0));
}
}
template<typename T>
static __dpct_inline__ T op_sigmoid(T x) {
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
}
template<typename T>
static __dpct_inline__ T op_sqrt(T x) {
return sycl::sqrt(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::sqrt(x);
} else {
return sycl::sqrt(x);
}
}
template<typename T>
static __dpct_inline__ T op_sin(T x) {
return sycl::sin(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::sin(x);
} else {
return sycl::sin(x);
}
}
template<typename T>
static __dpct_inline__ T op_cos(T x) {
return sycl::cos(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::cos(x);
} else {
return sycl::cos(x);
}
}
template<typename T>
static __dpct_inline__ T op_hardsigmoid(T x) {
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmin(
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
} else {
return sycl::fmin(static_cast<T>(1.0f),
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
}
template<typename T>
static __dpct_inline__ T op_hardswish(T x) {
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
template<typename T>
static __dpct_inline__ T op_exp(T x) {
return sycl::exp(x);
}
template<typename T>
static __dpct_inline__ T op_expm1(T x) {
return sycl::expm1(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
} else {
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
}
}
template<typename T>
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
if (x <= static_cast<T>(0)) {
return neg_infinity<T>();
}
return sycl::log(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::log(x);
} else {
return sycl::log(x);
}
}
template<typename T>
static __dpct_inline__ T op_softplus(T x) {
const float xf = (float) x;
const float ax = sycl::fabs(xf);
const float ax = op_abs(xf);
const float m = sycl::fmax(xf, 0.0f);
const float y = m + sycl::log1p(sycl::exp(-ax));
return (T) y;
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
template<typename T>
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
T neg_slope_T = static_cast<T>(negative_slope);
return sycl::fmax(x, static_cast<T>(0)) +
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
} else {
return sycl::fmax(x, static_cast<T>(0)) +
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
}
}
template<typename T>
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
template<typename T>
static __dpct_inline__ T op_floor(T x) {
return sycl::floor(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::floor(x);
} else {
return sycl::floor(x);
}
}
template<typename T>
static __dpct_inline__ T op_ceil(T x) {
return sycl::ceil(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::ceil(x);
} else {
return sycl::ceil(x);
}
}
template<typename T>
static __dpct_inline__ T op_round(T x) {
return sycl::round(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return static_cast<sycl::ext::oneapi::bfloat16>(
sycl::round(static_cast<float>(x))
);
} else {
return sycl::round(x);
}
}
template<typename T>
static __dpct_inline__ T op_trunc(T x) {
return sycl::trunc(x);
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
return sycl::ext::oneapi::experimental::trunc(x);
} else {
return sycl::trunc(x);
}
}
template<typename T, typename F>
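The op_* helpers in the hunks above select a bfloat16-specific branch at compile time with if constexpr; op_expm1, op_erf and op_round do so by converting to float, calling the float function, and converting back. The sketch below shows that pattern in standard C++ without SYCL. fake_bf16 is a hypothetical wrapper standing in for sycl::ext::oneapi::bfloat16, and op_round_sketch is an illustrative name, not a function from the backend.

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Hypothetical stand-in for a 16-bit float type that has no std::round overload.
struct fake_bf16 {
    float v;
    explicit fake_bf16(float f) : v(f) {}
    explicit operator float() const { return v; }
};

template <typename T>
T op_round_sketch(T x) {
    if constexpr (std::is_same_v<T, fake_bf16>) {
        // No native overload: widen to float, compute, narrow back.
        return static_cast<T>(std::round(static_cast<float>(x)));
    } else {
        // Types with a native overload call it directly.
        return std::round(x);
    }
}

// Usage: both instantiations compile, and each takes only its own branch.
inline void op_round_sketch_demo() {
    assert(op_round_sketch(2.5f) == 3.0f);
    assert(static_cast<float>(op_round_sketch(fake_bf16(-1.5f))) == -2.0f);
}
```

Because the discarded branch is never instantiated, the generic call does not need to compile for the wrapper type.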
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> /*item_ct1*/) {
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
});
}
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#ifdef GGML_SYCL_HAS_BF16
case GGML_TYPE_BF16:
{
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_generic_kernel(
src, dst_ptr, k_elements,
ne0, ne1, ne2, ne3,
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
arange_kernel(dst_ptr, k, start, step, item_ct1);
});
}
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
});
}, negative_slope);
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
});
}, min_val, max_val);
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
return out_glu;
}
template <typename T>
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
const int64_t n, const int64_t o0, const int64_t o1,
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
});
}
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
main_stream->parallel_for(
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
+5
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
# the result-checking path computes a CPU reference graph via
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
if (GGML_VULKAN_DEBUG)
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
uint32_t num_cols;
bool use_mmvq;
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
use_mmvq == other.use_mmvq;
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
}
};
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
uint32_t n_experts;
uint32_t num_cols;
int vectorized;
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
vectorized == other.vectorized;
num_cols == other.num_cols && vectorized == other.vectorized;
}
};
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.n_experts);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
const ggml_tensor * src1,
bool supports_dot_product,
const std::string & vendor) {
if (src1->ne[1] == 1) {
if (src1->ne[1] <= 4) {
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
if (supports_dp4a && supports_dot_product) {
switch (src1->type) {
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.num_cols = context.dst->ne[1];
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=1"));
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
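The pipeline-key hunks above add num_cols to both operator== and the hash functor. Since NUM_COLS is passed to the shader preprocessor as a define, two keys that differ only in num_cols have to map to different cached pipelines. A minimal sketch of that invariant follows. All names are hypothetical, the type fields are plain ints, and hash_combine_sketch uses a common combine formula as an assumption; the actual body of ggml_webgpu_hash_combine is not shown in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>

struct vec_key_sketch {
    int      src0_type;
    int      src1_type;
    int      vectorized;
    uint32_t num_cols;
    bool     use_mmvq;

    // Every field that changes the generated shader takes part in equality.
    bool operator==(const vec_key_sketch & o) const {
        return src0_type == o.src0_type && src1_type == o.src1_type &&
               vectorized == o.vectorized && num_cols == o.num_cols && use_mmvq == o.use_mmvq;
    }
};

// Assumed combine step (a widely used formula), for illustration only.
template <typename T>
inline void hash_combine_sketch(size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct vec_key_hash_sketch {
    size_t operator()(const vec_key_sketch & k) const {
        size_t seed = 0;
        hash_combine_sketch(seed, k.src0_type);
        hash_combine_sketch(seed, k.src1_type);
        hash_combine_sketch(seed, k.vectorized);
        hash_combine_sketch(seed, k.num_cols);
        hash_combine_sketch(seed, k.use_mmvq);
        return seed;
    }
};

// Usage: keys differing only in num_cols occupy separate cache slots.
inline size_t vec_key_cache_demo() {
    std::unordered_map<vec_key_sketch, int, vec_key_hash_sketch> cache;
    vec_key_sketch a{0, 1, 0, 1, true};
    vec_key_sketch b = a;
    b.num_cols = 4;
    cache[a] = 10;
    cache[b] = 40;
    assert(cache.at(b) == 40);
    return cache.size();
}
```

Leaving a field out of operator== would let a pipeline compiled for one column count be reused for another, while leaving it out of the hash alone would only cost extra collisions.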
+11 -9
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
const size_t q8_src1_align_offset = ROUNDUP_POW2(
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t q8_src1_binding_size =
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t q8_src1_binding_size = ROUNDUP_POW2(
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
std::vector<uint32_t> q8_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
uint32_t q8_wg_x = 1;
uint32_t q8_wg_y = 1;
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * dst) {
// Determine if this is a mat-vec operation
bool is_vec = (dst->ne[1] == 1);
bool use_mat_vec = (dst->ne[1] <= 4);
// use MMVQ path for mat-vec
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
webgpu_pipeline pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
if (is_vec) {
if (use_mat_vec) {
if (use_mmvq) {
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
}
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (is_vec) {
if (use_mat_vec) {
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
ctx->webgpu_global_ctx->vendor);
if (use_mmvq) {
const size_t q8_src1_size =
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
res = ROUNDUP_POW2(res + q8_src1_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
WEBGPU_STORAGE_BUF_BINDING_MULT);
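Review note: the size change above multiplies the Q8_1 scratch area by `src1->ne[1]`, so every column of `src1` gets its own quantized row. A minimal sketch of that arithmetic follows, assuming the 36-byte block and 32-element block size stated in the inline comments. `q8_1_scratch_bytes` is a hypothetical helper, not a function from the patch, and it leaves out the `ROUNDUP_POW2` alignment step.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed constants, taken from the inline comments in the diff above.
constexpr size_t Q8_1_BLOCK_BYTES = 36; // "sizeof(q8_1)" per the diff comment
constexpr size_t Q8_1_BLOCK_ELEMS = 32; // "block_size" per the diff comment

// Hypothetical helper: bytes needed to hold src1 quantized to Q8_1,
// counting every column (ne[1]) as well as the batch dims (ne[2], ne[3]).
// Alignment rounding is intentionally omitted.
size_t q8_1_scratch_bytes(const int64_t ne[4]) {
    assert(ne[0] % (int64_t) Q8_1_BLOCK_ELEMS == 0); // rows must be whole blocks
    const size_t blocks_per_row = (size_t) ne[0] / Q8_1_BLOCK_ELEMS;
    return (size_t) ne[3] * (size_t) ne[2] * (size_t) ne[1] *
           (Q8_1_BLOCK_BYTES * blocks_per_row);
}
```

With `ne[0] = 4096`, one column needs 128 blocks of 36 bytes; four columns need four times that, which is the growth the allocation-size hunk accounts for.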
@@ -103,7 +103,7 @@ fn main(
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
let subgroup_total = subgroupAdd(acc[0][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
@@ -126,7 +126,7 @@ fn main(
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
partial_sums[partial_index(row, thread_id)] = acc[0][row];
}
workgroupBarrier();
@@ -91,61 +91,67 @@ fn main(
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
#ifdef MMVQ
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
#else
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#endif
for (var col = 0u;col < NUM_COLS;col += 1) {
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[col][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + row] = row_total;
}
}
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + col * params.m + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[col][row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
}
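Review note: the shader now loops over `NUM_COLS` and writes each result to `dst[dst_idx_base + col * params.m + row]`. A CPU reference for that output layout can be sketched as below. It is an assumption-laden illustration: `mat_vec_multi_col` is a hypothetical name, the inputs are plain `float` rather than quantized blocks, and batching is ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference. a is m x k (row-major), b holds n_cols vectors
// of length k back to back, and the result is stored as dst[col * m + row],
// matching the index expression used in the shader hunk above.
std::vector<float> mat_vec_multi_col(const std::vector<float> & a,
                                     const std::vector<float> & b,
                                     size_t m, size_t k, size_t n_cols) {
    assert(a.size() == m * k && b.size() == n_cols * k);
    std::vector<float> dst(n_cols * m, 0.0f);
    for (size_t row = 0; row < m; ++row) {
        for (size_t col = 0; col < n_cols; ++col) {
            // The same weight row a[row * k ...] is reused for every column.
            float acc = 0.0f;
            for (size_t i = 0; i < k; ++i) {
                acc += a[row * k + i] * b[col * k + i];
            }
            dst[col * m + row] = acc;
        }
    }
    return dst;
}
```

Comparing the GPU output against a reference like this for `n_cols` from 1 to 4 would exercise the new `use_mat_vec` threshold.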
File diff suppressed because it is too large
@@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#endif // MUL_ACC_Q4_0
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
@@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
f32(load_f16_at_src0(block_byte_base + 2u))
);
}
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#endif // MUL_ACC_Q4_1
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
@@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds);
}
#endif
#endif // MUL_ACC_Q8_0
#ifdef LEGACY_QUANTS
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
var row_sum = 0;
let a_repacked = repack_a(a_byte_base, b_inner_id);
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
}
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
#if defined(LEGACY_QUANTS)
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let b_inner_id = thread_id % THREADS_PER_BLOCK;
let b_block_idx = src1q_idx_base + block;
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
let b_ds = repack_b_dm(b_block_idx);
let inner_id = thread_id % THREADS_PER_BLOCK;
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
let a_repacked = repack_a(block_byte_base, inner_id);
let da = get_dm(block_byte_base);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
let b_repacked = repack_b_qs(src1q_idx, inner_id);
let b_ds = repack_b_dm(src1q_idx);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
#if defined(MUL_ACC_Q4_0)
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_0
#if defined(MUL_ACC_Q4_1)
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_1
#if defined(MUL_ACC_Q8_0)
acc[col][row] += f32(row_sum) * (da * b_ds);
#endif // MUL_ACC_Q8_0
}
}
}
}
return acc;
}
#endif
#endif // LEGACY_QUANTS
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
@@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
}
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
}
#endif
#endif // MUL_ACC_Q2_K
#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
@@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
return vec2<f32>(scale, min_val);
}
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
}
#endif
#endif // MUL_ACC_Q4_K
#ifdef K_QUANTS
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
let tid = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
let a_repacked = repack_a(block_byte_base, tid);
let dm = get_dm(block_byte_base);
let scale_min = get_scale_min(block_byte_base, tid);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
#if defined(MUL_ACC_Q2_K)
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
#endif // MUL_ACC_Q2_K
#if defined(MUL_ACC_Q4_K)
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
#endif // MUL_ACC_Q4_K
}
}
}
}
return acc;
}
#endif
#endif // K_QUANTS
@@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product;
struct Params {
offset_src1: u32,
stride_11: u32,
stride_12: u32,
stride_13: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@@ -57,25 +59,28 @@ fn main(
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
let num_vec4 = params.ne0 / 4u;
let ne0_vec4 = params.ne0 / 4u;
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
if (wg_linear >= total_batches) {
return;
}
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let vec_idx = wg_linear / wg_per_vec;
let src13_idx = vec_idx / (params.ne2 * params.ne1);
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
let src12_idx = vec_ne12_num / params.ne1;
let src11_idx = vec_ne12_num % params.ne1;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
let src1_idx_vec4_base = src1_idx_base / 4u;
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
let qs_idx = thread_id % 8u;
@@ -85,7 +90,7 @@ fn main(
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
let is_valid = src11_vec4_idx < num_vec4;
let is_valid = src11_vec4_idx < ne0_vec4;
#ifdef USE_SUBGROUP_REDUCTION
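Review note: the quantize shader now decomposes the linear workgroup index into four coordinates instead of three, with the column index `src11_idx` varying fastest among the vector coordinates. The same arithmetic is sketched on the CPU below. `q8_wg_coords` and `decompose_wg` are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical struct naming the four coordinates the shader derives.
struct q8_wg_coords {
    uint32_t i3;        // src13_idx
    uint32_t i2;        // src12_idx
    uint32_t i1;        // src11_idx (the column, new in this change)
    uint32_t wg_in_vec; // src11_wg_idx
};

// Mirrors the shader arithmetic: strip the per-vector workgroup index first,
// then split the vector index into (i3, i2, i1).
q8_wg_coords decompose_wg(uint32_t wg_linear, uint32_t wg_per_vec,
                          uint32_t ne1, uint32_t ne2) {
    assert(wg_per_vec > 0 && ne1 > 0 && ne2 > 0);
    const uint32_t vec_idx = wg_linear / wg_per_vec;
    const uint32_t rem     = vec_idx % (ne2 * ne1);
    return { vec_idx / (ne2 * ne1), rem / ne1, rem % ne1, wg_linear % wg_per_vec };
}
```

Recombining the coordinates as `((i3 * ne2 + i2) * ne1 + i1) * wg_per_vec + wg_in_vec` should give back `wg_linear`, which is a quick sanity check on the mapping.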
+1
@@ -359,6 +359,7 @@ class Keys:
CHUNK_SIZE = "clip.audio.chunk_size"
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
MAX_POS_EMB = "clip.audio.max_pos_emb"
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
+3
@@ -1310,6 +1310,9 @@ class GGUFWriter:
def add_audio_max_pos_emb(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
def add_audio_projector_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
+9 -8
@@ -558,14 +558,15 @@ extern "C" {
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
+8
@@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
sched_need_reserve = true;
}
void llama_context::set_nextn_layer_offset(int32_t offset) {
cparams.nextn_layer_offset = offset;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
ctx->set_embeddings_layer_inp(lid, value);
}
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
ctx->set_nextn_layer_offset(offset);
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
if (!ctx) {
return nullptr;
+1
@@ -115,6 +115,7 @@ struct llama_context {
void set_embeddings (bool value);
void set_embeddings_nextn(bool value, bool masked);
void set_embeddings_layer_inp(uint32_t lid, bool enable);
void set_nextn_layer_offset(int32_t offset);
void set_causal_attn(bool value);
void set_warmup(bool value);
+2
@@ -18,6 +18,8 @@ struct llama_cparams {
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
int32_t nextn_layer_offset = 0;
float rope_freq_base;
float rope_freq_scale;
+5
@@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
// Select which appended NextN block the DECODER_MTP graph runs (offset past
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
// chain multiple trained NextN heads. Default 0 (first head).
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
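Review note: the comment on `llama_set_nextn_layer_offset` documents the selected block as `il = n_layer() + offset`. A minimal sketch of that rule and its valid range is shown below. `nextn_layer_index` is a hypothetical helper; its `n_layer` and `n_layer_nextn` arguments stand in for values a caller would presumably read via `llama_model_n_layer()` and `llama_model_n_layer_nextn()`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the documented rule il = n_layer + offset.
// The range check matches the condition the MTP graph asserts elsewhere in
// this diff: offset must lie in [0, n_layer_nextn).
int32_t nextn_layer_index(int32_t n_layer, int32_t n_layer_nextn, int32_t offset) {
    assert(offset >= 0 && offset < n_layer_nextn);
    return n_layer + offset;
}
```

For an illustrative model with 45 trunk layers and 3 NextN heads, offsets 0, 1 and 2 select layers 45, 46 and 47.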
+9 -2
@@ -682,9 +682,16 @@ struct llm_graph_params {
}
}
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
return false;
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
cparams.embeddings == other.cparams.embeddings &&
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
+4
@@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
return model->hparams.n_layer();
}
int32_t llama_model_n_layer_nextn(const llama_model * model) {
return model->hparams.n_layer_nextn;
}
int32_t llama_model_n_head(const llama_model * model) {
return model->hparams.n_head();
}
+2 -2
@@ -932,8 +932,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// copy the KV pairs from the input file
gguf_set_kv (ctx_out.get(), ml.metadata);
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION).c_str(), GGML_QNT_VERSION);
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_FILE_TYPE).c_str(), ftype);
// Remove split metadata
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
-2
@@ -2813,8 +2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
cur_p->data[i].logit = -INFINITY;
}
}
llama_sampler_softmax_impl(cur_p, true);
}
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+5 -5
@@ -101,11 +101,11 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) {
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
// DSA indexer
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags);
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags);
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags);
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags);
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags);
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags | TENSOR_NOT_REQUIRED);
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
if (i < (int) hparams.n_layer_dense_lead) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
+27 -28
@@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
};
auto load_block_mtp = [&](int i, bool is_first_mtp) {
auto load_block_mtp = [&](int i) {
auto & layer = layers[i];
const uint32_t n_head_l = hparams.n_head(i);
@@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
// The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
// NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
// `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
//
// Only the FIRST MTP block (i == n_main) is required for the
// single-block MTP runtime; trailing MTP blocks are always tolerated
// as missing so pruned GGUFs (block 0 only) load cleanly. Override
// mtp_flags to NOT_REQUIRED for those.
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
// Multi-block MTP: every declared MTP block is required (the draft chain
// runs all n_layer_nextn heads), so each block uses the captured
// `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF,
// which keeps that path correct.
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
@@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
}
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags);
// dense MLP (leading dense blocks) — present if the MTP block isn't MoE
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
@@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
// NextN-specific tensors that define the MTP block.
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
@@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
load_block_trunk(i, trunk_flags);
}
// Only the first MTP block (i == n_main) is required at runtime — the
// single-block-MTP graph in build_arch_graph always uses that one.
// Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
// all MTP layers still works) but tolerated when absent via the pruning
// path. See scripts/prune_step35_extra_mtp.py for the pruner.
// All n_layer_nextn MTP blocks are required — the multi-block draft chain
// runs every head (head k at offset k). The GGUF declares the count via
// step35.nextn_predict_layers.
for (int i = n_layer; i < n_layer_all; ++i) {
load_block_mtp(i, /*is_first_mtp=*/ i == n_layer);
load_block_mtp(i);
}
}
@@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
: llm_graph_context(params) {
GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0");
// Single-block MTP only: always run the first trained MTP block (Qwen
// MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
// be a much deeper refactor than this PR justifies; the trailing MTP
// blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
// block 0) also work — see load_arch_tensors below and
// scripts/prune_step35_extra_mtp.py.
const int il = hparams.n_layer();
// Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by
// cparams.nextn_layer_offset (0 = first trained head). The speculative driver
// bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps
// single-block behavior identical to before.
const int il = hparams.n_layer() + cparams.nextn_layer_offset;
GGML_ASSERT(cparams.nextn_layer_offset >= 0 &&
cparams.nextn_layer_offset < (int) hparams.n_layer_nextn &&
"nextn_layer_offset out of range [0, n_layer_nextn)");
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
@@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "mtp_post_ffn", il);
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
+148 -1
@@ -129,7 +129,154 @@ void test_gbnf_generation(testing &t) {
});
assert_gbnf_equal(t, R"""(
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])* ("<" | "</" | "</t" | "</ta" | "</tag")?
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [<] until-0-01 | [^<] until-0
until-0-01 ::= | [<] until-0-01 | [/] until-0-02 | [^/<] until-0
until-0-02 ::= | [<] until-0-01 | [t] until-0-03 | [^<t] until-0
until-0-03 ::= | [<] until-0-01 | [a] until-0-04 | [^<a] until-0
until-0-04 ::= | [<] until-0-01 | [g] until-0-05 | [^<g] until-0
until-0-05 ::= | [<] until-0-01 | [^<>] until-0
)""", gbnf);
});
t.test("until grammar overlapping delimiter", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until("\n</parameter>\n");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [\n] until-0-01 | [^\n] until-0
until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0
until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0
until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0
until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0
until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0
until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0
until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0
until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0
until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0
until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0
until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0
until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0
until-0-13 ::= | [^\n] until-0
)""", gbnf);
});
// DeepSeek-V3.2 tag prefix. The DSML token (｜DSML｜) embeds U+FF5C,
// so the delimiter mixes ASCII and multi-byte codepoints.
t.test("until grammar unicode delimiter", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until("<｜DSML｜");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [<] until-0-01 | [^<] until-0
until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0
until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^<D] until-0
until-0-03 ::= | [<] until-0-01 | [S] until-0-04 | [^<S] until-0
until-0-04 ::= | [<] until-0-01 | [M] until-0-05 | [^<M] until-0
until-0-05 ::= | [<] until-0-01 | [L] until-0-06 | [^<L] until-0
until-0-06 ::= | [<] until-0-01 | [^<\uFF5C] until-0
)""", gbnf);
});
t.test("until grammar multiple delimiters", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until_one_of({"ab", "cd", "ef"});
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^ace] until-0
until-0-01 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^abce] until-0
until-0-03 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acde] until-0
until-0-05 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acef] until-0
)""", gbnf);
});
t.test("ac grammar", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.ac(p.until("</tag>") + p.literal("</tag>"), "</tag>");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
ac-3 ::= [<] ac-3-01 | [^<] ac-3
ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3
ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^<t] ac-3
ac-3-03 ::= [<] ac-3-01 | [a] ac-3-04 | [^<a] ac-3
ac-3-04 ::= [<] ac-3-01 | [g] ac-3-05 | [^<g] ac-3
ac-3-05 ::= [>] | [<] ac-3-01 | [^<>] ac-3
root ::= ac-3
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("ac grammar terminates at first delimiter", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.ac(p.until("\n</parameter>\n") + p.literal("\n</parameter>\n"), "\n</parameter>\n");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
ac-3 ::= [\n] ac-3-01 | [^\n] ac-3
ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3
ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3
ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3
ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3
ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3
ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3
ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3
ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3
ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3
ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3
ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3
ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3
ac-3-13 ::= [\n] | [^\n] ac-3
root ::= ac-3
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
t.test("ac grammar multiple delimiters", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.ac(p.eps(), std::vector<std::string>{"ab", "cd", "ef"});
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1
ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1
ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1
ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1
root ::= ac-1
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""", gbnf);
});
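The `until grammar` expectations above read like a string-matching automaton over the delimiter: one rule per matched-prefix length, one alternative per character that extends or restarts a match, and a negated class that falls back to the start rule. The following is a minimal sketch of that idea for a single ASCII delimiter, assuming a KMP-style failure function. The names `esc`, `state_name` and `until_rules` are hypothetical and are not the `common_peg_parser_builder` implementation.

```cpp
#include <cassert>
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

// Escape only what this sketch needs; a real generator would also have to
// escape characters such as ']', '\\', '^' and '-' inside a character class.
static std::string esc(char c) {
    return c == '\n' ? "\\n" : std::string(1, c);
}

// Rule name for "the last s characters matched the first s of the delimiter".
static std::string state_name(const std::string & base, size_t s) {
    if (s == 0) {
        return base;
    }
    char buf[16];
    std::snprintf(buf, sizeof(buf), "-%02zu", s);
    return base + buf;
}

// Hypothetical helper: one rule per automaton state, accepting any text that
// never completes `delim`.
static std::vector<std::string> until_rules(const std::string & base, const std::string & delim) {
    const size_t n = delim.size();

    // KMP failure function: fail[j] is the longest proper border of delim[0..j).
    std::vector<size_t> fail(n + 1, 0);
    for (size_t i = 1, k = 0; i < n; ++i) {
        while (k > 0 && delim[i] != delim[k]) k = fail[k];
        if (delim[i] == delim[k]) ++k;
        fail[i + 1] = k;
    }

    const std::set<char> alphabet(delim.begin(), delim.end());
    std::vector<std::string> rules;
    for (size_t s = 0; s < n; ++s) {
        std::map<size_t, char> by_target; // next state -> character
        std::set<char>         excluded;  // characters kept out of the fallback class
        for (char c : alphabet) {
            size_t t = s;
            while (t > 0 && delim[t] != c) t = fail[t];
            if (delim[t] == c) ++t;
            if (t == n) {
                excluded.insert(c);       // would complete the delimiter: no alternative
            } else if (t > 0) {
                excluded.insert(c);
                by_target[t] = c;
            }
        }
        std::string rule = state_name(base, s) + " ::=";
        for (const auto & kv : by_target) {
            rule += " | [" + esc(kv.second) + "] " + state_name(base, kv.first);
        }
        std::string neg;
        for (char c : excluded) neg += esc(c);
        rule += " | [^" + neg + "] " + base;
        rules.push_back(rule);
    }
    return rules;
}
```

Calling `until_rules("until-0", "</tag>")` yields six rules in the same shape as the first expectation in this file. Multiple delimiters and non-ASCII codepoints would need an Aho-Corasick-style automaton over codepoints, which this sketch does not attempt.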
+11 -1
@@ -10,7 +10,7 @@
#undef NDEBUG
#include <cassert>
int main(void) {
static void test(void) {
common_params params;
printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
@@ -210,3 +210,13 @@ int main(void) {
printf("test-arg-parser: all tests OK\n\n");
}
int main(void) {
try {
test();
} catch (std::exception & e) {
fprintf(stderr, "test-arg-parser: exception: %s\n", e.what());
return 1;
}
return 0;
}
+2
@@ -8433,6 +8433,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
@@ -8449,6 +8450,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+101 -26
@@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() {
}
}
static void test_split_by_role() {
static void test_msg_token_delimiters_split() {
LOG_DBG("%s\n", __func__);
// Delimiters that share a leading token, distinguished by the second token,
// to exercise the per-position token matching.
const common_chat_msg_delimiters delims = {
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
};
// Empty inputs
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
assert_equals<size_t>(0, delims.split({}).spans.size());
// Multi-role conversation, no leading/trailing content
// No delimiters match -> no spans
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
{
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
const auto splits = common_chat_split_by_role(prompt, {
{ "user", "<|user|>" },
{ "assistant", "<|assistant|>" },
});
assert_equals<size_t>(3, splits.size());
const llama_tokens tokens = {
10, 11, // <user>
100, 101, // Hi
10, 12, // <assistant>
200, 201, 202, // Hello
10, 11, // <user>
300, 301, // Bye
};
assert_equals<std::string>("user", splits[0].role);
assert_equals<size_t>(0, splits[0].pos);
assert_equals<size_t>(10, splits[0].len);
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
const auto result = delims.split(tokens);
const auto & spans = result.spans;
assert_equals<size_t>(3, spans.size());
assert_equals<std::string>("assistant", splits[1].role);
assert_equals<size_t>(10, splits[1].pos);
assert_equals<size_t>(18, splits[1].len);
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(4, spans[0].len);
assert_equals<std::string>("user", splits[2].role);
assert_equals<size_t>(28, splits[2].pos);
assert_equals<size_t>(11, splits[2].len);
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(4, spans[1].pos);
assert_equals<size_t>(5, spans[1].len);
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
assert_equals<size_t>(9, spans[2].pos);
assert_equals<size_t>(4, spans[2].len);
// is_user_start() is true at the token position where a user span begins
assert_equals(true, result.is_user_start(0));
assert_equals(false, result.is_user_start(4)); // assistant span
assert_equals(true, result.is_user_start(9));
}
// Content before the first delimiter is not captured as a span
{
const llama_tokens tokens = {
500, 501, // leading content (dropped)
10, 11, // <user>
100, // Hi
};
const auto spans = delims.split(tokens).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(2, spans[0].pos);
assert_equals<size_t>(3, spans[0].len);
}
// Skipped regions (media chunks) are jumped over but still count as span content
{
const llama_tokens tokens = {
10, 11, // <user>
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
LLAMA_TOKEN_NULL,
LLAMA_TOKEN_NULL,
100, // Hi
10, 12, // <assistant>
};
const std::map<size_t, size_t> skips = { { 2, 3 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(2, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(6, spans[0].len);
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
assert_equals<size_t>(6, spans[1].pos);
assert_equals<size_t>(2, spans[1].len);
}
// A delimiter sequence inside a skipped region is not matched
{
const llama_tokens tokens = {
10, 11, // <user>
10, 12, // skipped region that happens to contain delimiter tokens
100, // Hi
};
const std::map<size_t, size_t> skips = { { 2, 2 } };
const auto spans = delims.split(tokens, skips).spans;
assert_equals<size_t>(1, spans.size());
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
assert_equals<size_t>(0, spans[0].pos);
assert_equals<size_t>(5, spans[0].len);
}
}
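The behaviour exercised by `test_msg_token_delimiters_split` can be restated as a small standalone routine: a span opens at each delimiter match and runs until the next match or the end of the input, tokens before the first match are dropped, and skipped regions are jumped over without being searched. The following is a sketch only, inferred from the assertions above. The types `token`, `delimiter` and `span` and the function `split_by_delimiters` are hypothetical stand-ins, not the `common_chat_msg_delimiters` implementation, and a delimiter that straddles the start of a skipped region is not handled specially.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

using token = int32_t; // stand-in for llama_token

struct delimiter {
    std::string        role;
    std::vector<token> tokens;
};

struct span {
    std::string role;
    size_t      pos;
    size_t      len;
};

// skips maps a start position to the number of tokens to jump over there.
static std::vector<span> split_by_delimiters(
        const std::vector<token>        & tokens,
        const std::vector<delimiter>    & delims,
        const std::map<size_t, size_t>  & skips = {}) {
    std::vector<span> spans;
    size_t i = 0;
    while (i < tokens.size()) {
        // Jump over a skipped region (for example a media chunk) without matching inside it.
        const auto skip = skips.find(i);
        if (skip != skips.end() && skip->second > 0) {
            i += skip->second;
            continue;
        }
        // Compare every delimiter token by token at the current position.
        const delimiter * hit = nullptr;
        for (const auto & d : delims) {
            if (!d.tokens.empty() && i + d.tokens.size() <= tokens.size() &&
                std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
                hit = &d;
                break;
            }
        }
        if (hit == nullptr) {
            ++i;
            continue;
        }
        // A new delimiter closes the previous span and opens the next one.
        if (!spans.empty()) {
            spans.back().len = i - spans.back().pos;
        }
        spans.push_back({ hit->role, i, 0 });
        i += hit->tokens.size();
    }
    if (!spans.empty()) {
        spans.back().len = tokens.size() - spans.back().pos;
    }
    return spans;
}
```

With delimiters `{10, 11}` for the user role and `{10, 12}` for the assistant role, the multi-role token sequence from the test produces spans at positions 0, 4 and 9 with lengths 4, 5 and 4.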
@@ -5022,14 +5097,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
tst.test(
"```json\n\"42\" \n```")
"```json\n\"42\"\n```")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.json_schema(const_schema)
.expect_content(R"("42")")
.run();
tst.test(
"\"42\" \n")
"\"42\"\n")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.json_schema(const_schema)
.expect_content(R"("42")")
@@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) {
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_split_by_role();
test_msg_token_delimiters_split();
test_tools_oaicompat_json_conversion();
test_convert_responses_to_chatcmpl();
test_developer_role_to_system_workaround();
+26
@@ -995,6 +995,32 @@ static void test_macros(testing & t) {
json::object(),
"Hello, John Smith,Hi, Jane Doe"
);
test_template(t, "macro with caller",
"\
{%- macro nest_dict(o, i, ff='') %}\n\
{{- caller(ff) }}\n\
{%- for k, v in o|items %}\n\
{{- i + k + ': ' }}\n\
{%- if v is mapping %}\n\
{{- '{' }}\n\
{% call(f) nest_dict(v, i + ' ') %}\n\
{{- 'fail' if ff is undefined }}\n\
{%- endcall %}\n\
{{- i + '}' }}\n\
{% else %}\n\
{{- v|string }}\n\
{% endif %}\n\
{%- endfor %}\n\
{%- endmacro %}\n\
{%- call(f) nest_dict({'root1': 1, 'root2': {'nest1': 1, 'nest2': {'nest3': 2}}}, ' ', 'Dict') %}\n\
{{- 'fail' if ff is defined }}\n\
{{- f + ' {' }}\n\
{% endcall %}\n\
{{- '}' }}",
json::object(),
"Dict {\n root1: 1\n root2: {\n nest1: 1\n nest2: {\n nest3: 2\n }\n }\n}"
);
}
static void test_namespace(testing & t) {
+155 -155
@@ -92,7 +92,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": 0
})""",
R"""(
root ::= ([0] | [1-9] [0-9]{0,15}) space
root ::= ([0] | [1-9] [0-9]{0,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -105,7 +105,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": 1
})""",
R"""(
root ::= ([1-9] [0-9]{0,15}) space
root ::= ([1-9] [0-9]{0,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -118,7 +118,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": 3
})""",
R"""(
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -131,7 +131,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": 9
})""",
R"""(
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -144,7 +144,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": 10
})""",
R"""(
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -157,7 +157,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": 25
})""",
R"""(
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -170,7 +170,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 30
})""",
R"""(
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0"))
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -183,7 +183,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": -5
})""",
R"""(
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -196,7 +196,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minimum": -123
})""",
R"""(
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15})
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -209,7 +209,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": -5
})""",
R"""(
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15}))
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -222,7 +222,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 1
})""",
R"""(
root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
root ::= ("-" [1-9] [0-9]{0,15} | [0-1])
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -235,7 +235,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 100
})""",
R"""(
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100")
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -249,7 +249,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 23
})""",
R"""(
root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
root ::= ([0-9] | ([1] [0-9] | [2] [0-3]))
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -263,7 +263,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 300
})""",
R"""(
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00"))
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -277,7 +277,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 30
})""",
R"""(
root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
root ::= ([5-9] | ([1-2] [0-9] | [3] "0"))
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -291,7 +291,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 42
})""",
R"""(
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2]))
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -305,7 +305,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maximum": 10
})""",
R"""(
root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
root ::= ("-" ([0-9] | "10") | [0-9] | "10")
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -333,17 +333,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"empty schema (object)",
"{}",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
array ::= "[" space ( value ("," space value)* )? space "]"
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
null ::= "null"
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
root ::= object
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
value ::= object | array | string | number | boolean | null
)"""
});
@@ -361,17 +361,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
date ::= [0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
date-string ::= "\"" date "\"" space
date-string ::= "\"" date "\""
date-time ::= date "T" time
date-time-string ::= "\"" date-time "\"" space
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
date-time-string ::= "\"" date-time "\""
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
time-string ::= "\"" time "\"" space
time-string ::= "\"" time "\""
tuple-0 ::= date-string
tuple-2 ::= time-string
tuple-3 ::= date-time-string
uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space
uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""
)"""
});
@@ -383,7 +383,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char* "\"" space
root ::= "\"" char* "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -397,7 +397,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char+ "\"" space
root ::= "\"" char+ "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -411,7 +411,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{3,} "\"" space
root ::= "\"" char{3,} "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -425,7 +425,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{0,3} "\"" space
root ::= "\"" char{0,3} "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -440,7 +440,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{1,4} "\"" space
root ::= "\"" char{1,4} "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -452,7 +452,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"type": "boolean"
})""",
R"""(
root ::= ("true" | "false") space
root ::= ("true" | "false")
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -465,7 +465,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) space
root ::= ("-"? integral-part)
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -477,7 +477,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"const": "foo"
})""",
R"""(
root ::= "\"foo\"" space
root ::= "\"foo\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -489,7 +489,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"const": 123
})""",
R"""(
root ::= "123" space
root ::= "123"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -501,7 +501,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
R"""(
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]")
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -515,9 +515,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "[" space (string ("," space string)*)? "]" space
root ::= "[" space (string ("," space string)*)? space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -529,12 +529,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"prefixItems": { "type": "string" }
})""",
R"""(
alternative-0 ::= "[" space (string ("," space string)*)? "]" space
alternative-0 ::= "[" space (string ("," space string)*)? space "]"
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
null ::= "null" space
null ::= "null"
root ::= alternative-0 | null
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -546,9 +546,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "[" space string "]" space
root ::= "[" space string space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -562,10 +562,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space string "," space number "]" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "[" space string "," space number space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -577,18 +577,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"items": {}
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
array ::= "[" space ( value ("," space value)* )? space "]"
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
item ::= object
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= "[" space (item ("," space item)*)? "]" space
null ::= "null"
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
root ::= "[" space (item ("," space item)*)? space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
value ::= object | array | string | number | boolean | null
)"""
});
@@ -602,18 +602,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"prefixItems": { "type": "string" }
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
array ::= "[" space ( value ("," space value)* )? space "]"
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
item ::= object
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= "[" space (item ("," space item)*)? "]" space
null ::= "null"
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
root ::= "[" space (item ("," space item)*)? space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
value ::= object | array | string | number | boolean | null
)"""
});
@@ -627,7 +627,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -642,8 +642,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"minItems": 2
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean ("," space boolean)+ "]" space
boolean ::= ("true" | "false")
root ::= "[" space boolean ("," space boolean)+ space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -658,8 +658,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxItems": 0
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space "]" space
boolean ::= ("true" | "false")
root ::= "[" space space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -674,8 +674,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxItems": 1
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean? "]" space
boolean ::= ("true" | "false")
root ::= "[" space boolean? space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -690,8 +690,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxItems": 2
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space (boolean ("," space boolean)?)? "]" space
boolean ::= ("true" | "false")
root ::= "[" space (boolean ("," space boolean)?)? space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -708,11 +708,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
decimal-part ::= [0-9]{1,16}
integer ::= ("-"? integral-part) space
integer ::= ("-"? integral-part)
integral-part ::= [0] | [1-9] [0-9]{0,15}
item ::= number | integer
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space item ("," space item){2,4} "]" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "[" space item ("," space item){2,4} space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -730,8 +730,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxItems": 5
})""",
R"""(
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
root ::= "[" space item ("," space item){2,4} "]" space
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7]))
root ::= "[" space item ("," space item){2,4} space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -749,8 +749,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"maxItems": 5
})""",
R"""(
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
root ::= "[" space item ("," space item){2,4} "]" space
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7]))
root ::= "[" space item ("," space item){2,4} space "]"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -763,7 +763,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^abc?d*efg+(hij)?kl$"
})""",
R"""(
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -776,7 +776,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
})""",
R"""(
root ::= "\"" ("[]{}()|+*?") "\"" space
root ::= "\"" ("[]{}()|+*?") "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -789,7 +789,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^\"$"
})""",
R"""(
root ::= "\"" ("\"") "\"" space
root ::= "\"" ("\"") "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -802,7 +802,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"pattern": "^A|B|C|D$"
})""",
R"""(
root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
root ::= "\"" ("A" | "B" | "C" | "D") "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -816,7 +816,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
dot ::= [^\x0A\x0D]
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\""
root-1 ::= [0-9]
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
@@ -845,9 +845,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
root ::= "{" space b-kv "," space c-kv "," space a-kv space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -865,9 +865,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
a-kv ::= "\"a\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space (a-kv )? "}" space
root ::= "{" space (a-kv )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -889,9 +889,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
b-rest ::= ( "," space c-kv )?
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -915,9 +915,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
d-kv ::= "\"d\"" space ":" space string
d-rest ::= ( "," space c-kv )?
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -930,14 +930,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
additional-kv ::= string ":" space additional-value
additional-value ::= "[" space (number ("," space number)*)? "]" space
additional-value ::= "[" space (number ("," space number)*)? space "]"
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (additional-kv ( "," space additional-kv )* )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "{" space (additional-kv ( "," space additional-kv )* )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -949,17 +949,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"additionalProperties": true
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
array ::= "[" space ( value ("," space value)* )? space "]"
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
null ::= "null"
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
root ::= object
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
value ::= object | array | string | number | boolean | null
)"""
});
@@ -971,17 +971,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"type": "object"
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
array ::= "[" space ( value ("," space value)* )? space "]"
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
null ::= "null"
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
root ::= object
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
value ::= object | array | string | number | boolean | null
)"""
});
@@ -994,7 +994,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"additionalProperties": false
})""",
R"""(
root ::= "{" space "}" space
root ::= "{" space space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1012,15 +1012,15 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["]
additional-kv ::= additional-k ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -1037,13 +1037,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
a-kv ::= "\"a\"" space ":" space number
a-rest ::= ( "," space additional-kv )*
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["]
additional-kv ::= additional-k ":" space number
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1061,7 +1061,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"additionalProperties": {"type": "number"}
})""",
R"""(
additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
additional-kv ::= additional-k ":" space number
also-kv ::= "\"also\"" space ":" space number
also-rest ::= ( "," space additional-kv )*
@@ -1069,8 +1069,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1090,13 +1090,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
-rest ::= ( "," space a-kv )? a-rest
a-kv ::= "\"a\"" space ":" space integer
a-rest ::= ( "," space additional-kv )*
additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["] space
additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["]
additional-kv ::= additional-k ":" space integer
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
integer ::= ("-"? integral-part) space
integer ::= ("-"? integral-part)
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) space
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
root ::= ("-"? integral-part)
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1116,12 +1116,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-rest ::= ( "," space aa-kv )? aa-rest
aa-kv ::= "\"aa\"" space ":" space integer
aa-rest ::= ( "," space additional-kv )*
additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["] space
additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["]
additional-kv ::= additional-k ":" space integer
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
integer ::= ("-"? integral-part) space
integer ::= ("-"? integral-part)
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1141,12 +1141,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
ab-rest ::= ( "," space ac-kv )? ac-rest
ac-kv ::= "\"ac\"" space ":" space integer
ac-rest ::= ( "," space additional-kv )*
additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["] space
additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["]
additional-kv ::= additional-k ":" space integer
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
integer ::= ("-"? integral-part) space
integer ::= ("-"? integral-part)
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1173,11 +1173,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv "}" space
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv space "}"
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space string
root ::= ref-definitions-foo
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -1204,10 +1204,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
alternative-1 ::= ref-definitions-bar
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? space "}"
ref-definitions-bar-b-kv ::= "\"b\"" space ":" space number
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? "}" space
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? space "}"
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space number
root ::= alternative-0 | alternative-1
space ::= | " " | "\n"{1,2} [ \t]{0,20}
@@ -1241,14 +1241,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
b ::= b-0 | boolean
b-0 ::= string
b-kv ::= "\"b\"" space ":" space b
boolean ::= ("true" | "false") space
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (a-kv a-rest | b-kv )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "{" space (a-kv a-rest | b-kv )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
)"""
});
@@ -1290,8 +1290,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
d-rest ::= ( "," space c-kv )?
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1311,7 +1311,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
}
})""",
R"""(
root ::= ("\"a\"" | "\"b\"") space
root ::= ("\"a\"" | "\"b\"")
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1336,7 +1336,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
}
})""",
R"""(
root ::= ("\"b\"" | "\"c\"") space
root ::= ("\"b\"" | "\"c\"")
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1378,13 +1378,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
number- ::= "{" space number-number-kv "}" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
number- ::= "{" space number-number-kv space "}"
number-kv ::= "\"number\"" space ":" space number-
number-number ::= "{" space number-number-root-kv "}" space
number-number ::= "{" space number-number-root-kv space "}"
number-number-kv ::= "\"number\"" space ":" space number-number
number-number-root-kv ::= "\"root\"" space ":" space number
root ::= "{" space number-kv "}" space
root ::= "{" space number-kv space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1394,17 +1394,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"description only (no type) treated as unconstrained",
R"""({"description": "The 0-based index of the last line to be retrieved (inclusive). If None, read until the end of the file."})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
array ::= "[" space ( value ("," space value)* )? space "]"
boolean ::= ("true" | "false")
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
null ::= "null"
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
root ::= value
space ::= | " " | "\n"{1,2} [ \t]{0,20}
string ::= "\"" char* "\"" space
string ::= "\"" char* "\""
value ::= object | array | string | number | boolean | null
)"""
});
@@ -1428,9 +1428,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"type": "object"
})""",
R"""(
code ::= "\" \\r \\n \\\" \\\\ \"" space
code ::= "\" \\r \\n \\\" \\\\ \""
code-kv ::= "\"code\"" space ":" space code
root ::= "{" space code-kv "}" space
root ::= "{" space code-kv space "}"
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
@@ -1547,7 +1547,7 @@ int main() {
"pattern": "^(?:foo|bar)baz$"
})""",
R"""(
root ::= "\"" (("foo" | "bar") "baz") "\"" space
root ::= "\"" (("foo" | "bar") "baz") "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""",
});
@@ -1560,7 +1560,7 @@ int main() {
"pattern": "^(?:(?:ab)+c)?d$"
})""",
R"""(
root ::= "\"" ((("ab")+ "c")? "d") "\"" space
root ::= "\"" ((("ab")+ "c")? "d") "\""
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""",
});
+2 -2
@@ -360,9 +360,9 @@ int main(void) {
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
+4 -2
View File
@@ -2,11 +2,13 @@
set(TARGET llama-cli-impl)
add_library(${TARGET} cli.cpp)
add_library(${TARGET} cli.cpp
cli-client.cpp
cli-context.cpp)
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
if(LLAMA_TOOLS_INSTALL)
install(TARGETS ${TARGET} LIBRARY)
+164
View File
@@ -0,0 +1,164 @@
#include "cli-client.h"
#include "http.h"
#include <algorithm>
#include <chrono>
#include <stdexcept>
#include <thread>
// generation can stall for a long time during prompt processing, so the
// read timeout must be generous
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
// upper bound for the accumulated response body kept for error reporting
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
// returns the path with the base url's path prefix prepended (if any)
static std::string join_path(const common_http_url & parts, const std::string & path) {
if (parts.path.empty() || parts.path == "/") {
return path;
}
std::string prefix = parts.path;
if (prefix.back() == '/') {
prefix.pop_back();
}
return prefix + path;
}
json cli_client::get(const std::string & path) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
auto res = cli.Get(join_path(parts, path_with_model));
if (!res) {
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
}
if (res->status < 200 || res->status >= 300) {
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
}
json result = json::parse(res->body, nullptr, false);
if (result.is_discarded()) {
throw std::runtime_error("GET " + path + " returned invalid JSON");
}
return result;
}
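Note that `get()` appends the model name to the query string verbatim. If model identifiers may contain characters such as `&`, `#`, or spaces, the value would need percent-encoding first. A minimal sketch of such a step follows; `percent_encode` is a hypothetical helper that is not part of this change, and whether the server requires encoded names is an assumption.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not part of this change): percent-encode a query
// parameter value, leaving RFC 3986 unreserved characters untouched.
static std::string percent_encode(const std::string & value) {
    static const char hex[] = "0123456789ABCDEF";
    std::string out;
    for (unsigned char c : value) {
        const bool unreserved =
            (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') ||
            c == '-' || c == '.' || c == '_' || c == '~';
        if (unreserved) {
            out += static_cast<char>(c);
        } else {
            // everything else becomes %XX with uppercase hex digits
            out += '%';
            out += hex[c >> 4];
            out += hex[c & 0x0F];
        }
    }
    return out;
}
```

With such a helper, the query suffix could be built as `"?model=" + percent_encode(model)`.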
json cli_client::post(const std::string & path, const json & body) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
auto body_with_model = body;
if (!model.empty()) {
body_with_model["model"] = model;
}
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
if (!res) {
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
}
if (res->status < 200 || res->status >= 300) {
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
}
json result = json::parse(res->body, nullptr, false);
if (result.is_discarded()) {
throw std::runtime_error("POST " + path + " returned invalid JSON");
}
return result;
}
json cli_client::post_sse(const std::string & path,
const json & body,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data) {
auto [cli, parts] = common_http_client(server_base);
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
std::string pending; // buffer for incomplete SSE lines
std::string raw_body; // accumulated body, used only for error reporting
auto receiver = [&](const char * data, size_t len) -> bool {
if (should_stop()) {
return false; // aborts the request
}
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
}
pending.append(data, len);
size_t pos;
while ((pos = pending.find('\n')) != std::string::npos) {
std::string line = pending.substr(0, pos);
pending.erase(0, pos + 1);
if (!line.empty() && line.back() == '\r') {
line.pop_back();
}
if (line.rfind("data: ", 0) != 0) {
continue;
}
std::string payload = line.substr(6);
if (payload == "[DONE]") {
continue;
}
json event = json::parse(payload, nullptr, false);
if (!event.is_discarded()) {
on_data(event);
}
}
return true;
};
httplib::Headers headers = {{"Accept", "text/event-stream"}};
auto body_with_model = body;
if (!model.empty()) {
body_with_model["model"] = model;
}
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
if (!res) {
if (res.error() == httplib::Error::Canceled && should_stop()) {
return json(); // cancelled by the user
}
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
}
if (res->status < 200 || res->status >= 300) {
json error_body = json::parse(raw_body, nullptr, false);
if (!error_body.is_discarded() && error_body.contains("error")) {
return error_body;
}
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
}
return json();
}
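The line-buffering logic inside the `receiver` lambda can be sketched in isolation, without the HTTP client or JSON parsing. The helper name `feed_sse` is hypothetical; it only illustrates how partial chunks are held in `pending` until a full line arrives, and how `[DONE]` and non-`data:` lines are skipped.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: append one network chunk to `pending` and return the
// payloads of every complete "data: " line now available. Incomplete trailing
// data stays in `pending` for the next call.
static std::vector<std::string> feed_sse(std::string & pending, const std::string & chunk) {
    std::vector<std::string> payloads;
    pending += chunk;
    std::size_t pos;
    while ((pos = pending.find('\n')) != std::string::npos) {
        std::string line = pending.substr(0, pos);
        pending.erase(0, pos + 1);
        if (!line.empty() && line.back() == '\r') {
            line.pop_back(); // tolerate CRLF line endings
        }
        if (line.rfind("data: ", 0) != 0) {
            continue; // blank separators and any other SSE fields are ignored
        }
        std::string payload = line.substr(6);
        if (payload != "[DONE]") {
            payloads.push_back(payload); // the caller would JSON-parse this
        }
    }
    return payloads;
}
```

Because a chunk boundary can fall anywhere, an event split across two reads yields nothing on the first call and the full payload on the second.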
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
int connect_attempts = 0;
while (!is_aborted()) {
auto [cli, parts] = common_http_client(server_base);
cli.set_connection_timeout(1, 0);
auto res = cli.Get(join_path(parts, "/health"));
if (res) {
if (res->status == 200) {
return true;
}
// any other status means the server is up but not ready yet
// (e.g. 503 while the model is still loading)
} else if (++connect_attempts >= 10) {
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
return false;
}
std::this_thread::sleep_for(std::chrono::milliseconds(300));
}
last_error = "aborted while waiting for the server to become ready";
return false;
}
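The polling strategy of `wait_health()` distinguishes "reachable but not ready" (keep waiting indefinitely) from "unreachable" (give up after a bounded number of failures). A minimal sketch of the same loop, with the HTTP call abstracted behind a hypothetical `probe` callback; `probe_result` and `wait_until_ready` are illustrative names, not part of this change.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <thread>

enum class probe_result { ready, not_ready, unreachable };

// Hypothetical generic form of the loop: `probe` stands in for GET /health.
static bool wait_until_ready(const std::function<probe_result()> & probe,
                             const std::function<bool()> & is_aborted,
                             int max_connect_failures,
                             std::chrono::milliseconds delay) {
    int failures = 0; // counts all connection failures, never reset
    while (!is_aborted()) {
        switch (probe()) {
            case probe_result::ready:
                return true;
            case probe_result::not_ready:
                break; // server is up but still loading; keep polling
            case probe_result::unreachable:
                if (++failures >= max_connect_failures) {
                    return false;
                }
                break;
        }
        std::this_thread::sleep_for(delay);
    }
    return false; // aborted by the caller
}
```

As in the original, a server that answers with a non-200 status is polled without limit, so the abort callback is the only way out of a load that never finishes.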
std::vector<std::string> cli_client::list_models() {
json resp = get("/v1/models");
if (!resp.contains("data") || !resp.at("data").is_array()) {
throw std::runtime_error("invalid response from /v1/models");
}
std::vector<std::string> models;
for (const auto & m : resp.at("data")) {
if (m.contains("id") && m.at("id").is_string()) {
models.push_back(m.at("id").get<std::string>());
}
}
return models;
}
+56
View File
@@ -0,0 +1,56 @@
#pragma once
#include "ggml.h"
#define JSON_ASSERT GGML_ASSERT
#include <nlohmann/json.hpp>
#include <functional>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
// openai-like client for CLI
struct cli_client {
std::string server_base; // base url, for example "http://127.0.0.1:8080"
std::string last_error; // set when wait_health() fails
std::string model; // optional, set when the server has multiple models (router mode)
// simple GET request, returns the response json
// throws std::runtime_error on transport error or non-2xx status
json get(const std::string & path);
// simple POST request, returns the response json
// throws std::runtime_error on transport error or non-2xx status
json post(const std::string & path, const json & body);
// POST request with an SSE streaming response; on_data is invoked once
// per "data:" event; the function returns after the stream is finished:
// a null json on graceful exit (incl. cancellation via should_stop),
// the error response json otherwise
json post_sse(const std::string & path,
const json & body,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data);
// poll /health until the server is ready to accept requests
// returns false if is_aborted returned true or the server is unreachable
bool wait_health(const std::function<bool()> & is_aborted);
//
// higher-level wrappers
//
json create_chat_completion(const json & request,
const std::function<bool()> & should_stop,
const std::function<void(const json &)> & on_data) {
return post_sse("/v1/chat/completions", request, should_stop, on_data);
}
json get_props() {
return get("/props");
}
std::vector<std::string> list_models();
};
+559
View File
@@ -0,0 +1,559 @@
#include "cli-context.h"
#include "cli-view.h"
#include "arg.h"
#include "base64.hpp"
#include "log.h"
#include "console.h"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <map>
#include <set>
std::atomic<bool> g_cli_interrupted = false;
static bool should_stop() {
return g_cli_interrupted.load();
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
const char * LLAMA_ASCII_LOGO = R"(
)";
// number of values an arg consumes on the command line
static int arg_num_values(const common_arg & opt) {
if (opt.value_hint_2 != nullptr) {
return 2;
}
if (opt.value_hint != nullptr) {
return 1;
}
return 0;
}
static std::string format_error_message(const json & err) {
if (err.contains("error") && err.at("error").is_object()) {
const auto & e = err.at("error");
if (e.contains("message") && e.at("message").is_string()) {
return e.at("message").get<std::string>();
}
}
return err.dump();
}
static std::string media_type_from_ext(const std::string & fname) {
std::string ext = std::filesystem::path(fname).extension().string();
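// cast through unsigned char: passing a negative char to tolower is undefined behavior
std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) { return (char) ::tolower(c); });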
if (ext == ".wav" || ext == ".mp3") {
return "audio";
}
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
return "video";
}
return "image";
}
bool cli_context::init() {
view::init(params);
std::optional<view::spinner> spinner;
bool use_external_server = !params.server_base.empty();
if (use_external_server) {
std::string base = params.server_base;
while (!base.empty() && base.back() == '/') {
base.pop_back();
}
client.server_base = base;
spinner.emplace("Connecting to server at " + base);
} else {
if (params.model.path.empty() && params.model.url.empty() &&
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
view::show_error(
"no model specified",
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
"or --server-base <url> to connect to a running llama-server"
);
return false;
}
spinner.emplace("\n\nLoading model...");
server.emplace();
if (!server->start(params)) {
view::show_error("server start failed");
return false;
}
if (!server->wait_ready(should_stop)) {
if (!should_stop()) {
view::show_error("the server exited before becoming ready");
}
return false;
}
client.server_base = server->address();
}
// for --server-base this is the main availability check; for a spawned
// server it is a cheap sanity check on top of the ready signal
auto is_aborted = [this]() {
return should_stop() || (server && !server->alive());
};
bool healthy = false;
try {
healthy = client.wait_health(is_aborted);
} catch (const std::exception & e) {
client.last_error = e.what();
}
if (!healthy) {
if (!should_stop()) {
view::show_error(client.last_error);
}
return false;
}
if (use_external_server) {
spinner.reset();
if (!list_and_ask_models()) {
return false;
}
// restore the spinner for the next step
spinner.emplace("Waiting for server...");
}
fetch_server_props();
return true;
}
void cli_context::fetch_server_props() {
try {
json props = client.get_props();
model_name = props.value("model_alias", "");
if (model_name.empty()) {
const std::string path = props.value("model_path", "");
if (!path.empty()) {
model_name = std::filesystem::path(path).filename().string();
}
}
build_info = props.value("build_info", "");
if (props.contains("modalities") && props.at("modalities").is_object()) {
const auto & modalities = props.at("modalities");
has_vision = modalities.value("vision", false);
has_audio = modalities.value("audio", false);
has_video = modalities.value("video", false);
}
} catch (const std::exception & e) {
// /props can be disabled on remote servers; not fatal
LOG_DBG("failed to fetch /props: %s\n", e.what());
}
}
bool cli_context::list_and_ask_models() {
auto models = client.list_models();
// only one model: use it without asking
if (models.size() == 1) {
model_name = models[0];
client.model = model_name;
return true;
}
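// nothing to select from: fail instead of prompting forever
if (models.empty()) {
view::show_error("the server did not report any models");
return false;
}
std::string message = "\nAvailable models:";
for (size_t i = 0; i < models.size(); ++i) {
message += "\n " + std::to_string(i + 1) + ". " + models[i];
}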
message += "\n";
view::show_message(message);
std::string selection;
while (selection.empty()) {
if (should_stop()) {
return false;
}
view::user_turn user_turn;
selection = user_turn.read_input(false, "Select model by number: ");
if (selection.empty()) {
continue;
}
try {
size_t idx = std::stoul(selection);
if (idx > 0 && idx <= models.size()) {
model_name = models[idx - 1];
client.model = model_name;
view::show_message("Selected model: " + model_name);
break;
}
} catch (...) {
// ignore
}
view::show_error("Invalid selection. Please enter a valid number.");
selection.clear();
continue;
}
return true;
}
void cli_context::add_system_prompt() {
if (!params.system_prompt.empty()) {
messages.push_back({
{"role", "system"},
{"content", params.system_prompt}
});
}
}
void cli_context::push_user_message(const std::string & text) {
json content;
if (pending_media.empty()) {
content = text;
} else {
// multimodal message: media parts first, then the text
content = pending_media;
content.push_back({
{"type", "text"},
{"text", text}
});
pending_media = json::array();
}
messages.push_back({
{"role", "user"},
{"content", content}
});
}
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
std::ifstream file(fname, std::ios::binary);
if (!file) {
return false;
}
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
std::string encoded = base64::encode(data);
if (type == "audio") {
std::string ext = std::filesystem::path(fname).extension().string();
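// cast through unsigned char: passing a negative char to tolower is undefined behavior
std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) { return (char) ::tolower(c); });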
pending_media.push_back({
{"type", "input_audio"},
{"input_audio", {
{"data", encoded},
{"format", ext == ".mp3" ? "mp3" : "wav"}
}}
});
} else if (type == "video") {
pending_media.push_back({
{"type", "input_video"},
{"input_video", {
{"data", encoded}
}}
});
} else {
// the server detects the actual image type from the data
pending_media.push_back({
{"type", "image_url"},
{"image_url", {
{"url", "data:image/unknown;base64," + encoded}
}}
});
}
return true;
}
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
json body = {
{"messages", messages},
{"stream", true},
// in order to get timings even when we cancel mid-way
{"timings_per_token", true},
};
bool stream_error = false;
view::assistant_turn a;
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
if (chunk.contains("error")) {
stream_error = true;
view::show_error(format_error_message(chunk));
return;
}
if (chunk.contains("timings")) {
const auto & t = chunk.at("timings");
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
}
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
return;
}
const auto & choice = chunk.at("choices").at(0);
if (!choice.contains("delta")) {
return;
}
const auto & delta = choice.at("delta");
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
const std::string text = delta.at("reasoning_content").get<std::string>();
if (!text.empty()) {
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
}
}
if (delta.contains("content") && delta.at("content").is_string()) {
const std::string text = delta.at("content").get<std::string>();
if (!text.empty()) {
assistant_content += text;
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
}
}
});
g_cli_interrupted.store(false);
if (!err.is_null()) {
view::show_error(format_error_message(err));
return false;
}
return !stream_error;
}
int cli_context::run() {
add_system_prompt();
std::string modalities = "text";
if (has_vision) {
modalities += ", vision";
}
if (has_audio) {
modalities += ", audio";
}
if (has_video) {
modalities += ", video";
}
std::string banner;
banner += "\n";
banner += LLAMA_ASCII_LOGO;
banner += "\n";
banner += "build : " + build_info + "\n";
banner += "model : " + model_name + "\n";
banner += "modalities : " + modalities + "\n";
if (!params.system_prompt.empty()) {
banner += "using custom system prompt\n";
}
banner += "\n";
banner += "available commands:\n";
banner += " /exit or Ctrl+C stop or exit\n";
banner += " /regen regenerate the last response\n";
banner += " /clear clear the chat history\n";
banner += " /read <file> add a text file\n";
banner += " /glob <pattern> add text files using globbing pattern\n";
if (has_vision) {
banner += " /image <file> add an image file\n";
}
if (has_audio) {
banner += " /audio <file> add an audio file\n";
}
if (has_video) {
banner += " /video <file> add a video file\n";
}
banner += "\n";
view::show_message(banner);
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::ifstream file(fname, std::ios::binary);
if (!file) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
return false;
}
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
cur_msg += content;
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
return true;
};
while (true) {
std::string buffer;
{
view::user_turn user_turn;
if (params.prompt.empty()) {
buffer = user_turn.read_input(params.multiline_input);
} else {
// process input prompt from args
for (auto & fname : params.image) {
if (!stage_media_file(fname, media_type_from_ext(fname))) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
break;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
}
buffer = params.prompt;
user_turn.echo(buffer);
params.prompt.clear(); // only use it once
}
}
if (should_stop()) {
g_cli_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() && buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (messages.size() >= 2) {
size_t last_idx = messages.size() - 1;
messages.erase(last_idx);
add_user_msg = false;
} else {
view::show_error("No message to regenerate.");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
messages.clear();
add_system_prompt();
pending_media = json::array();
view::show_message("Chat history cleared.");
continue;
} else if (
(string_starts_with(buffer, "/image ") && has_vision) ||
(string_starts_with(buffer, "/audio ") && has_audio) ||
(string_starts_with(buffer, "/video ") && has_video)) {
std::string type = buffer.substr(1, 5);
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
if (!stage_media_file(fname, type)) {
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
continue;
}
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
push_user_message(cur_msg);
cur_msg.clear();
}
cli_timings timings;
std::string assistant_content;
generate_completion(assistant_content, timings);
messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
if (params.show_timings) {
view::show_info(string_format(
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
timings.prompt_per_second,
timings.predicted_per_second
));
}
if (params.single_turn) {
break;
}
}
view::show_message("\n\nExiting...");
return 0;
}
void cli_context::shutdown() {
if (server) {
server->stop();
server.reset();
}
}
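The image branch of `stage_media_file` above wraps the raw file bytes in a base64 data URL and leaves type detection to the server. A minimal, self-contained sketch of that encoding step follows. `b64_encode` and `make_image_data_url` are hypothetical helpers written for illustration only (the patch itself calls `base64::encode`), so this approximates the step rather than reproducing the code the CLI runs.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical stand-in for base64::encode: standard alphabet with '=' padding.
static std::string b64_encode(const std::string & in) {
    static const char tbl[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    std::string out;
    out.reserve(((in.size() + 2) / 3) * 4);
    std::size_t i = 0;
    // every full 3-byte group becomes 4 output characters
    for (; i + 2 < in.size(); i += 3) {
        const std::uint32_t n = (std::uint32_t(std::uint8_t(in[i]))     << 16) |
                                (std::uint32_t(std::uint8_t(in[i + 1])) <<  8) |
                                 std::uint32_t(std::uint8_t(in[i + 2]));
        out += tbl[(n >> 18) & 63];
        out += tbl[(n >> 12) & 63];
        out += tbl[(n >>  6) & 63];
        out += tbl[n & 63];
    }
    // one or two leftover bytes become a padded group
    const std::size_t rem = in.size() - i;
    if (rem == 1) {
        const std::uint32_t n = std::uint32_t(std::uint8_t(in[i])) << 16;
        out += tbl[(n >> 18) & 63];
        out += tbl[(n >> 12) & 63];
        out += "==";
    } else if (rem == 2) {
        const std::uint32_t n = (std::uint32_t(std::uint8_t(in[i]))     << 16) |
                                (std::uint32_t(std::uint8_t(in[i + 1])) <<  8);
        out += tbl[(n >> 18) & 63];
        out += tbl[(n >> 12) & 63];
        out += tbl[(n >>  6) & 63];
        out += '=';
    }
    return out;
}

// Hypothetical helper: builds the value stored under "image_url" -> "url".
// The MIME type is deliberately "unknown", matching the placeholder used above.
static std::string make_image_data_url(const std::string & bytes) {
    return "data:image/unknown;base64," + b64_encode(bytes);
}
```

Calling `make_image_data_url` with the file contents read in binary mode yields a string that can be placed in an `image_url` content part.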
@@ -0,0 +1,65 @@
#pragma once
#include "common.h"
#include "cli-client.h"
#include "cli-server.h"
#include <atomic>
#include <optional>
#include <string>
struct cli_timings {
double prompt_per_second = 0.0;
double predicted_per_second = 0.0;
};
// set by the SIGINT handler; cleared once the interrupt has been handled
extern std::atomic<bool> g_cli_interrupted;
struct cli_context {
common_params params;
cli_client client; // always initialized
std::optional<cli_server> server; // only set when no --server-base is given
json messages = json::array();
json pending_media = json::array(); // staged multimodal content parts
// properties of the connected server
// will be populated by fetch_server_props()
std::string model_name;
std::string build_info;
bool has_vision = false;
bool has_audio = false;
bool has_video = false;
cli_context(const common_params & params) : params(params) {}
~cli_context() {
shutdown();
}
// connect to --server-base, or start a local llama-server when none is given
bool init();
// run the interactive chat loop, returns the process exit code
int run();
// stop the local server (if any)
void shutdown();
private:
bool generate_completion(std::string & assistant_content, cli_timings & timings);
void fetch_server_props();
void add_system_prompt();
void push_user_message(const std::string & text);
// check whether the server has multiple models (router mode)
// if so, list them and ask the user to choose; otherwise do nothing
bool list_and_ask_models();
// read a file and stage it as a multimodal content part; type is one of
// "image", "audio", "video"; returns false if the file cannot be read
bool stage_media_file(const std::string & fname, const std::string & type);
};
@@ -0,0 +1,83 @@
#pragma once
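#include "common.h"
#include "http.h"
#include <atomic>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <string>
#include <thread>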
// llama_server will be available as a dynamic library symbol
int llama_server(common_params & params, int argc, char ** argv);
void llama_server_terminate();
struct cli_server {
std::thread th;
int port = -1;
std::atomic<bool> is_alive = false;
std::atomic<bool> is_stopping = false;
~cli_server() {
stop();
}
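void stop() {
if (alive() && !is_stopping.exchange(true)) {
llama_server_terminate();
}
// always join: the thread may have exited on its own (e.g. llama_server failed),
// and destroying a joinable std::thread calls std::terminate
if (th.joinable()) {
th.join();
}
}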
// spawn llama-server in a thread and interact with it via a random port
bool start(common_params & params) {
port = common_http_get_free_port();
if (port <= 0) {
fprintf(stderr, "failed to get a free port\n");
exit(1);
}
is_alive.store(true, std::memory_order_release);
th = std::thread([&]() {
common_params server_params = params; // copy
server_params.port = port;
// argc / argv are only used in router mode, we can skip them for now
int res = llama_server(server_params, 0, nullptr);
if (res != 0) {
fprintf(stderr, "llama_server exited with code %d\n", res);
}
is_alive.store(false, std::memory_order_release);
});
return true;
}
std::string address() const {
return "http://127.0.0.1:" + std::to_string(port);
}
bool wait_ready(std::function<bool()> should_stop) {
if (!alive()) {
return false;
}
while (!should_stop()) {
auto [cli, parts] = common_http_client(address());
cli.set_connection_timeout(1, 0);
auto res = cli.Get("/health");
if (res) {
if (res->status == 200) {
return true;
}
// any other status means the server is up but not ready yet
// (e.g. 503 while the model is still loading)
}
if (!alive()) {
// the server thread has exited and will never become ready
return false;
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
return true;
}
bool alive() const {
return is_alive.load(std::memory_order_acquire);
}
};
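`wait_ready` above polls `/health` until the server answers 200, the server thread exits, or the user interrupts. The control flow can be separated from the HTTP call as in the sketch below. `poll_until_ready` and `ready_result` are hypothetical names, and the probe, liveness check and stop check are passed in as callbacks. Unlike `wait_ready`, this sketch reports an interrupt as its own result, which is a design choice of the example and not something the patch does.

```cpp
#include <chrono>
#include <functional>
#include <thread>

enum class ready_result { ready, died, interrupted };

// Hypothetical helper:
//   probe()       returns true once the health check succeeds
//   alive()       returns false once the server thread has exited
//   should_stop() returns true when the user asked to cancel
static ready_result poll_until_ready(
        const std::function<bool()> & probe,
        const std::function<bool()> & alive,
        const std::function<bool()> & should_stop,
        std::chrono::milliseconds interval = std::chrono::milliseconds(200)) {
    while (!should_stop()) {
        if (probe()) {
            return ready_result::ready;
        }
        if (!alive()) {
            // no point in retrying: nothing is listening anymore
            return ready_result::died;
        }
        std::this_thread::sleep_for(interval);
    }
    return ready_result::interrupted;
}
```

A caller would pass a probe that issues the `/health` request and checks for status 200, and would treat anything other than `ready_result::ready` as a failed start.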
@@ -0,0 +1,250 @@
#pragma once
#include "common.h"
#include "console.h"
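#include <algorithm>
#include <array>
#include <cctype>
#include <cstdlib>
#include <filesystem>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>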
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
std::vector<std::pair<std::string, size_t>> matches;
std::string cmd;
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return string_starts_with(line, prefix);
})) {
auto it = cmds.begin();
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
return string_starts_with(cmd_line, line);
})) != cmds.end()) {
matches.emplace_back(*it, it->length());
++it;
}
} else {
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return prefix.back() == ' ' && string_starts_with(line, prefix);
});
if (it != cmds.end()) {
cmd = *it;
}
}
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
std::string cur_dir_str = cur_dir.string();
std::string expanded_prefix = path_prefix;
#if !defined(_WIN32)
if (string_starts_with(path_prefix, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
expanded_prefix = home + path_prefix.substr(1);
}
}
if (string_starts_with(expanded_prefix, '/')) {
#else
if (std::isalpha(static_cast<unsigned char>(expanded_prefix[0])) && expanded_prefix.find(':') == 1) {
#endif
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
cur_dir_str.clear();
} else if (!path_prefix.empty()) {
cur_dir /= std::filesystem::path(path_prefix).parent_path();
}
std::error_code ec;
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
if (ec) {
break;
}
if (!entry.exists(ec)) {
ec.clear();
continue;
}
const std::string path_full = entry.path().string();
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
if (entry.is_directory(ec)) {
path_entry.push_back(std::filesystem::path::preferred_separator);
}
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
const std::string updated_line = cmd + path_entry;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
if (ec) {
ec.clear();
}
}
if (matches.empty()) {
const std::string updated_line = cmd + path_prefix;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
// Add the longest common prefix
if (!expanded_prefix.empty() && matches.size() > 1) {
const std::string_view match0(matches[0].first);
const std::string_view match1(matches[1].first);
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
size_t len = it.first - match0.begin();
for (size_t i = 2; i < matches.size(); ++i) {
const std::string_view matchi(matches[i].first);
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
}
const std::string updated_line = std::string(match0.substr(0, len));
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
});
}
return matches;
}
// note: make this view implementation generic, so that we can move to TUI in the future if we want to
namespace view {
static void init(const common_params & params) {
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
console::set_completion_callback(auto_completion_callback);
}
struct spinner {
spinner(const std::string & message) {
if (!message.empty()) {
console::log("%s ", message.c_str());
}
console::spinner::start();
}
~spinner() {
console::spinner::stop();
}
};
struct user_turn {
user_turn() {
console::set_display(DISPLAY_TYPE_USER_INPUT);
}
~user_turn() {
console::set_display(DISPLAY_TYPE_RESET);
}
void echo(const std::string & buffer) {
if (buffer.size() > 500) {
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
} else {
console::log("\n> %s\n", buffer.c_str());
}
}
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
if (prompt) {
console::log("%s", prompt);
} else {
console::log("\n> ");
}
std::string buffer;
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, multiline_input);
buffer += line;
} while (another_line);
return buffer;
}
};
enum assistant_display_mode {
ASSISTANT_DISPLAY_MODE_REASONING,
ASSISTANT_DISPLAY_MODE_CONTENT,
};
struct assistant_turn {
assistant_display_mode mode = ASSISTANT_DISPLAY_MODE_CONTENT;
bool trailing_newline = true;
bool is_inside_reasoning = false;
assistant_turn() {
console::set_display(DISPLAY_TYPE_RESET);
}
~assistant_turn() {
console::set_display(DISPLAY_TYPE_RESET);
add_newline_if_needed();
}
void push(assistant_display_mode m, const std::string & buffer) {
if (m != mode) {
add_newline_if_needed();
switch (m) {
case ASSISTANT_DISPLAY_MODE_CONTENT:
{
if (is_inside_reasoning) {
console::log("[End thinking]\n\n");
is_inside_reasoning = false;
}
console::set_display(DISPLAY_TYPE_RESET);
} break;
case ASSISTANT_DISPLAY_MODE_REASONING:
{
console::set_display(DISPLAY_TYPE_REASONING);
is_inside_reasoning = true;
console::log("\n[Start thinking]\n\n");
} break;
}
}
mode = m;
if (buffer.empty()) {
return;
}
trailing_newline = buffer.back() == '\n';
console::log("%s", buffer.c_str());
console::flush();
}
void add_newline_if_needed() {
if (!trailing_newline) {
console::log("\n");
console::flush();
}
}
};
static void show_error(const std::string & title, const std::string & message = "") {
console::spinner::stop();
console::error("Error: %s\n", title.c_str());
if (!message.empty()) {
console::log("%s\n", message.c_str());
}
}
static void show_message(const std::string & message) {
console::log("%s\n", message.c_str());
}
static void show_info(const std::string & message) {
console::set_display(DISPLAY_TYPE_INFO);
console::log("%s\n", message.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
}
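The completion callback above appends the longest common prefix of all candidates so that the line can be extended as far as the matches agree. The sketch below isolates that step using the same `std::mismatch` approach. `longest_common_prefix` is a hypothetical helper introduced for illustration, and it works on plain strings instead of the `(line, cursor)` pairs used by the callback.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical helper: returns the longest prefix shared by every item.
static std::string longest_common_prefix(const std::vector<std::string> & items) {
    if (items.empty()) {
        return "";
    }
    const std::string_view first(items[0]);
    std::size_t len = first.size();
    for (std::size_t i = 1; i < items.size(); ++i) {
        const std::string_view cur(items[i]);
        // mismatch stops at the first differing character or at the end of either range
        auto cmp = std::mismatch(first.begin(), first.end(), cur.begin(), cur.end());
        len = std::min(len, static_cast<std::size_t>(cmp.first - first.begin()));
    }
    return std::string(first.substr(0, len));
}
```

Comparing every candidate against the first one is sufficient, because a prefix shared by all items is necessarily a prefix of the first.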
@@ -1,20 +1,10 @@
#include "chat.h"
#include "common.h"
#include "arg.h"
#include "console.h"
#include "fit.h"
// #include "log.h"
#include "common.h"
#include "log.h"
#include "server-common.h"
#include "server-context.h"
#include "server-task.h"
#include "cli-context.h"
#include "cli-view.h"
#include <array>
#include <atomic>
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <thread>
#include <signal.h>
#if defined(_WIN32)
@@ -25,342 +15,19 @@
#include <windows.h>
#endif
const char * LLAMA_ASCII_LOGO = R"(
)";
static std::atomic<bool> g_is_interrupted = false;
static bool should_stop() {
return g_is_interrupted.load();
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
if (g_is_interrupted.load()) {
if (g_cli_interrupted.load()) {
// second Ctrl+C - exit immediately
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
fprintf(stdout, "\033[0m\n");
fflush(stdout);
std::exit(130);
}
g_is_interrupted.store(true);
g_cli_interrupted.store(true);
}
#endif
struct cli_context {
server_context ctx_server;
json messages = json::array();
std::vector<raw_buffer> input_files;
task_params defaults;
bool verbose_prompt;
// thread for showing "loading" animation
std::atomic<bool> loading_show;
cli_context(const common_params & params) {
defaults.sampling = params.sampling;
defaults.speculative = params.speculative;
defaults.n_keep = params.n_keep;
defaults.n_predict = params.n_predict;
defaults.antiprompt = params.antiprompt;
defaults.stream = true; // make sure we always use streaming mode
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
// defaults.return_progress = true; // TODO: show progress
verbose_prompt = params.verbose_prompt;
}
std::string generate_completion(result_timings & out_timings) {
server_response_reader rd = ctx_server.get_response_reader();
auto chat_params = format_chat();
{
// TODO: reduce some copies here in the future
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_prompt = chat_params.prompt; // copy
task.cli_files = input_files; // copy
task.cli = true;
// chat template settings
task.params.chat_parser_params = common_chat_parser_params(chat_params);
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
if (!chat_params.parser.empty()) {
task.params.chat_parser_params.parser.load(chat_params.parser);
}
// Copy the preserved tokens into the sampling params
const llama_vocab * vocab = llama_model_get_vocab(
llama_get_model(ctx_server.get_llama_context()));
for (const auto & token : chat_params.preserved_tokens) {
auto ids = common_tokenize(vocab, token, false, true);
if (ids.size() == 1) {
task.params.sampling.preserved_tokens.insert(ids[0]);
}
}
// reasoning budget sampler
if (!chat_params.thinking_end_tag.empty()) {
task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
task.params.sampling.generation_prompt = chat_params.generation_prompt;
if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start =
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
}
task.params.sampling.reasoning_budget_end =
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
task.params.sampling.reasoning_budget_forced =
common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true);
}
rd.post_task({std::move(task)});
}
if (verbose_prompt) {
console::set_display(DISPLAY_TYPE_PROMPT);
console::log("%s\n\n", chat_params.prompt.c_str());
console::set_display(DISPLAY_TYPE_RESET);
}
// wait for first result
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);
while (true) {
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial && res_partial->is_begin) {
// this is the "send 200 status to client" signal in streaming mode
// skip, do not stop the spinner
result = rd.next(should_stop);
} else {
console::spinner::stop();
break;
}
}
std::string curr_content;
bool is_thinking = false;
while (result) {
if (should_stop()) {
break;
}
if (result->is_error()) {
json err_data = result->to_json();
if (err_data.contains("message")) {
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
} else {
console::error("Error: %s\n", err_data.dump().c_str());
}
return curr_content;
}
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial) {
out_timings = std::move(res_partial->timings);
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (is_thinking) {
console::log("\n[End thinking]\n\n");
console::set_display(DISPLAY_TYPE_RESET);
is_thinking = false;
}
curr_content += diff.content_delta;
console::log("%s", diff.content_delta.c_str());
console::flush();
}
if (!diff.reasoning_content_delta.empty()) {
console::set_display(DISPLAY_TYPE_REASONING);
if (!is_thinking) {
console::log("[Start thinking]\n");
}
is_thinking = true;
console::log("%s", diff.reasoning_content_delta.c_str());
console::flush();
}
}
}
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
if (res_final) {
out_timings = std::move(res_final->timings);
break;
}
result = rd.next(should_stop);
}
g_is_interrupted.store(false);
// server_response_reader automatically cancels pending tasks upon destruction
return curr_content;
}
// TODO: support remote files in the future (http, https, etc)
std::string load_input_file(const std::string & fname, bool is_media) {
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
if (!file) {
return "";
}
if (is_media) {
raw_buffer buf;
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
input_files.push_back(std::move(buf));
return get_media_marker();
} else {
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
return content;
}
}
common_chat_params format_chat() {
auto meta = ctx_server.get_meta();
auto & chat_params = meta.chat_params;
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = chat_params.use_jinja;
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
inputs.add_generation_prompt = true;
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
inputs.force_pure_content = chat_params.force_pure_content;
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
// Apply chat template to the list of messages
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
}
};
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
std::vector<std::pair<std::string, size_t>> matches;
std::string cmd;
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return string_starts_with(line, prefix);
})) {
auto it = cmds.begin();
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
return string_starts_with(cmd_line, line);
})) != cmds.end()) {
matches.emplace_back(*it, it->length());
++it;
}
} else {
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
return prefix.back() == ' ' && string_starts_with(line, prefix);
});
if (it != cmds.end()) {
cmd = *it;
}
}
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
std::string cur_dir_str = cur_dir.string();
std::string expanded_prefix = path_prefix;
#if !defined(_WIN32)
if (string_starts_with(path_prefix, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
expanded_prefix = home + path_prefix.substr(1);
}
}
if (string_starts_with(expanded_prefix, '/')) {
#else
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
#endif
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
cur_dir_str.clear();
} else if (!path_prefix.empty()) {
cur_dir /= std::filesystem::path(path_prefix).parent_path();
}
std::error_code ec;
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
if (ec) {
break;
}
if (!entry.exists(ec)) {
ec.clear();
continue;
}
const std::string path_full = entry.path().string();
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
if (entry.is_directory(ec)) {
path_entry.push_back(std::filesystem::path::preferred_separator);
}
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
const std::string updated_line = cmd + path_entry;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
if (ec) {
ec.clear();
}
}
if (matches.empty()) {
const std::string updated_line = cmd + path_prefix;
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
// Add the longest common prefix
if (!expanded_prefix.empty() && matches.size() > 1) {
const std::string_view match0(matches[0].first);
const std::string_view match1(matches[1].first);
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
size_t len = it.first - match0.begin();
for (size_t i = 2; i < matches.size(); ++i) {
const std::string_view matchi(matches[i].first);
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
}
const std::string updated_line = std::string(match0.substr(0, len));
matches.emplace_back(updated_line + path_postfix, updated_line.length());
}
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
});
}
return matches;
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
// satisfies -Wmissing-declarations
int llama_cli(int argc, char ** argv);
@@ -375,25 +42,6 @@ int llama_cli(int argc, char ** argv) {
return 1;
}
// TODO: maybe support it later?
if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
console::error("--no-conversation is not supported by llama-cli\n");
console::error("please use llama-completion instead\n");
}
// struct that contains llama context and inference
cli_context ctx_cli(params);
llama_backend_init();
llama_numa_init(params.numa);
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
console::set_display(DISPLAY_TYPE_RESET);
console::set_completion_callback(auto_completion_callback);
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@@ -408,273 +56,11 @@ int llama_cli(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
console::log("\nLoading model... "); // followed by loading animation
console::spinner::start();
if (!ctx_cli.ctx_server.load_model(params)) {
console::spinner::stop();
console::error("\nFailed to load the model\n");
cli_context ctx_cli(params);
if (!ctx_cli.init()) {
return 1;
}
ctx_cli.defaults.sampling = params.sampling;
console::spinner::stop();
console::log("\n");
std::thread inference_thread([&ctx_cli]() {
ctx_cli.ctx_server.start_loop();
});
auto inf = ctx_cli.ctx_server.get_meta();
std::string modalities = "text";
if (inf.has_inp_image) {
modalities += ", vision";
}
if (inf.has_inp_audio) {
modalities += ", audio";
}
auto add_system_prompt = [&]() {
if (!params.system_prompt.empty()) {
ctx_cli.messages.push_back({
{"role", "system"},
{"content", params.system_prompt}
});
}
};
add_system_prompt();
console::log("\n");
console::log("%s\n", LLAMA_ASCII_LOGO);
console::log("build : %s\n", inf.build_info.c_str());
console::log("model : %s\n", inf.model_name.c_str());
console::log("modalities : %s\n", modalities.c_str());
if (!params.system_prompt.empty()) {
console::log("using custom system prompt\n");
}
console::log("\n");
console::log("available commands:\n");
console::log(" /exit or Ctrl+C stop or exit\n");
console::log(" /regen regenerate the last response\n");
console::log(" /clear clear the chat history\n");
console::log(" /read <file> add a text file\n");
console::log(" /glob <pattern> add text files using globbing pattern\n");
if (inf.has_inp_image) {
console::log(" /image <file> add an image file\n");
}
if (inf.has_inp_audio) {
console::log(" /audio <file> add an audio file\n");
}
if (inf.has_inp_video) {
console::log(" /video <file> add a video file\n");
}
console::log("\n");
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::string marker = ctx_cli.load_input_file(fname, false);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
return false;
}
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
cur_msg += fname;
cur_msg.push_back('\n');
} else {
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
}
cur_msg += marker;
console::log("Loaded text from '%s'\n", fname.c_str());
return true;
};
while (true) {
std::string buffer;
console::set_display(DISPLAY_TYPE_USER_INPUT);
if (params.prompt.empty()) {
console::log("\n> ");
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
} else {
// process input prompt from args
for (auto & fname : params.image) {
std::string marker = ctx_cli.load_input_file(fname, true);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
break;
}
console::log("Loaded media from '%s'\n", fname.c_str());
cur_msg += marker;
}
buffer = params.prompt;
if (buffer.size() > 500) {
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
} else {
console::log("\n> %s\n", buffer.c_str());
}
params.prompt.clear(); // only use it once
}
console::set_display(DISPLAY_TYPE_RESET);
console::log("\n");
if (should_stop()) {
g_is_interrupted.store(false);
break;
}
// remove trailing newline
if (!buffer.empty() &&buffer.back() == '\n') {
buffer.pop_back();
}
// skip empty messages
if (buffer.empty()) {
continue;
}
bool add_user_msg = true;
// process commands
if (string_starts_with(buffer, "/exit")) {
break;
} else if (string_starts_with(buffer, "/regen")) {
if (ctx_cli.messages.size() >= 2) {
size_t last_idx = ctx_cli.messages.size() - 1;
ctx_cli.messages.erase(last_idx);
add_user_msg = false;
} else {
console::error("No message to regenerate.\n");
continue;
}
} else if (string_starts_with(buffer, "/clear")) {
ctx_cli.messages.clear();
add_system_prompt();
ctx_cli.input_files.clear();
console::log("Chat history cleared.\n");
continue;
} else if (
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
std::string marker = ctx_cli.load_input_file(fname, true);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
continue;
}
cur_msg += marker;
console::log("Loaded media from '%s'\n", fname.c_str());
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, '~')) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = home + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
break;
}
}
continue;
} else {
// not a command
cur_msg += buffer;
}
// generate response
if (add_user_msg) {
ctx_cli.messages.push_back({
{"role", "user"},
{"content", cur_msg}
});
cur_msg.clear();
}
result_timings timings;
std::string assistant_content = ctx_cli.generate_completion(timings);
ctx_cli.messages.push_back({
{"role", "assistant"},
{"content", assistant_content}
});
console::log("\n");
if (params.show_timings) {
console::set_display(DISPLAY_TYPE_INFO);
console::log("\n");
console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
console::set_display(DISPLAY_TYPE_RESET);
}
if (params.single_turn) {
break;
}
}
console::set_display(DISPLAY_TYPE_RESET);
console::log("\nExiting...\n");
ctx_cli.ctx_server.terminate();
inference_thread.join();
// bump the log level to display timings
common_log_set_verbosity_thold(LOG_LEVEL_INFO);
common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
return 0;
return ctx_cli.run();
}
+1 -1
View File
@@ -42,6 +42,7 @@
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
#define KEY_FEATURE_LAYERS "clip.%s.feature_layer"
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
@@ -54,7 +55,6 @@
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side"
#define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side"
+3 -3
View File
@@ -91,7 +91,7 @@ struct clip_hparams {
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> vision_feature_layer;
std::vector<int32_t> feature_layers;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
@@ -165,8 +165,8 @@ struct clip_hparams {
return false;
}
bool is_vision_feature_layer(int32_t layer) const {
return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end();
bool is_feature_layer(int32_t layer) const {
return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end();
}
};
+68 -34
View File
@@ -1045,8 +1045,17 @@ struct clip_model_loader {
bool has_vision = false;
bool has_audio = false;
mtmd_progress_callback progress_callback = nullptr;
void * progress_callback_user_data = nullptr;
// TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) {
clip_model_loader(const char * fname,
bool skip_tensors = false,
mtmd_progress_callback progress_cb = nullptr,
void * progress_user_data = nullptr)
: fname(fname),
progress_callback(progress_cb),
progress_callback_user_data(progress_user_data) {
struct ggml_context * meta = nullptr;
struct gguf_init_params params = {
@@ -1255,12 +1264,10 @@ struct clip_model_loader {
}
}
// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// Load the vision/audio feature layer indices if they are explicitly provided
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false);
get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false);
// model-specific params
switch (model.proj_type) {
@@ -1642,6 +1649,7 @@ struct clip_model_loader {
get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size);
get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate);
get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count);
// NOTE: feature layers loaded above in common path
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
@@ -1654,11 +1662,11 @@ struct clip_model_loader {
hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW;
hparams.image_resize_pad = PAD_CEIL;
get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer);
// NOTE: feature_layers loaded in common path as optional
get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets);
if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) {
throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d",
hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size()));
if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) {
// pass __func__ for the leading %s and print size_t values with %zu
throw std::runtime_error(string_format("%s: feature_layers.size() %zu != proj_spatial_offsets.size() %zu",
__func__, hparams.feature_layers.size(), hparams.proj_spatial_offsets.size()));
}
get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side);
@@ -2731,7 +2739,7 @@ struct clip_model_loader {
model.image_newline = get_tensor(TN_IMAGE_NEWLINE);
// Load separate layerwise and spatial projector tensors
const auto projector_count = hparams.vision_feature_layer.size();
const auto projector_count = hparams.feature_layers.size();
model.qf_proj_blocks.resize(projector_count);
for (size_t bid = 0; bid < projector_count; ++bid) {
auto & b = model.qf_proj_blocks[bid];
@@ -2787,37 +2795,60 @@ struct clip_model_loader {
}
// load data
if (!ctx_clip.no_alloc) {
{
std::vector<uint8_t> read_buf;
// start loading event
if (progress_callback) {
progress_callback(0.0, progress_callback_user_data);
}
// compute total tensor data size for progress reporting
size_t total_data_size = 0;
for (auto & t : tensors_to_load) {
total_data_size += ggml_nbytes(t);
}
// alloc memory and offload data
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
for (auto & t : tensors_to_load) {
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
GGML_ASSERT(cur && "tensor not found in ctx_data");
auto it_off = tensor_offset.find(t->name);
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
const size_t offset = it_off->second;
fin.seekg(offset, std::ios::beg);
if (!fin) {
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
}
size_t num_bytes = ggml_nbytes(cur);
if (ggml_backend_buft_is_host(buft)) {
// for the CPU and Metal backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(num_bytes);
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
// read the weight from file
if (!ctx_clip.no_alloc) {
size_t data_loaded = 0;
for (auto & t : tensors_to_load) {
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
GGML_ASSERT(cur && "tensor not found in ctx_data");
auto it_off = tensor_offset.find(t->name);
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
const size_t offset = it_off->second;
fin.seekg(offset, std::ios::beg);
if (!fin) {
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
}
size_t num_bytes = ggml_nbytes(cur);
if (ggml_backend_buft_is_host(buft)) {
// for the CPU and Metal backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(num_bytes);
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
}
data_loaded += num_bytes;
if (progress_callback && total_data_size > 0) {
const float progress = (float)data_loaded / (float)total_data_size;
if (!progress_callback(progress, progress_callback_user_data)) {
throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__));
}
}
}
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
} else {
LOG_DBG("%s: no_alloc is set, skipping tensor data loading (%zu tensors)\n", __func__, tensors_to_load.size());
}
fin.close();
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
}
}
@@ -3105,7 +3136,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
clip_ctx * ctx_audio = nullptr;
try {
clip_model_loader loader(fname);
clip_model_loader loader(fname,
/* skip_tensors */ false,
ctx_params.progress_callback,
ctx_params.progress_callback_user_data);
bool skip_audio = false;
if (loader.has_vision) {
@@ -4353,7 +4387,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32
// Stage 1b only uses block 0's permutations; future stages
// will upload all blocks.
for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) {
for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) {
const std::string prefix = "g4v_blk" + std::to_string(bid) + "_";
upload(prefix + "win_idx", make_win_idx(image_side, window_side));
upload(prefix + "qwin_idx", make_win_idx(new_side, query_side));
+2
View File
@@ -54,6 +54,8 @@ struct clip_context_params {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
bool no_alloc;
mtmd_progress_callback progress_callback;
void * progress_callback_user_data;
};
struct clip_init_result {
+35 -1
View File
@@ -1,5 +1,7 @@
#include "models.h"
#include <algorithm>
ggml_cgraph * clip_graph_granite_speech::build() {
const int n_frames = img.nx();
const int context_size = hparams.audio_chunk_size;
@@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() {
const int padded_len = num_blocks * context_size;
const int remainder = n_frames % context_size;
// Calculate projector input dimension based on feature layers
const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1);
const bool use_feature_concat = !hparams.feature_layers.empty();
ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size);
ggml_set_name(attn_dists, "attn_dists");
ggml_set_input(attn_dists);
@@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_add(ctx0, cur, model.inp_proj_b);
cb(cur, "inp_linear", -1);
// Capture layer 0 if requested (after input_linear)
ggml_tensor * concat_result = nullptr;
if (use_feature_concat) {
if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) {
concat_result = cur;
cb(concat_result, "feature_layer_0", -1);
}
}
for (int il = 0; il < n_layer; il++) {
const auto & layer = model.layers[il];
auto * residual = cur;
@@ -168,6 +183,18 @@ ggml_cgraph * clip_graph_granite_speech::build() {
NORM_TYPE_NORMAL, eps, il);
cb(cur, "layer_out", il);
// Capture intermediate layer (il + 1) if requested
if (use_feature_concat) {
if (hparams.is_feature_layer(il + 1)) {
if (concat_result == nullptr) {
concat_result = cur;
} else {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
}
cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il);
}
}
// CTC branch
if (il + 1 == ctc_layer) {
auto * mid = build_mm(model.ctc_out_w, cur);
@@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() {
}
}
// Append final output to concatenated features if using feature concatenation
if (use_feature_concat && concat_result != nullptr) {
concat_result = ggml_concat(ctx0, concat_result, cur, 0);
cb(concat_result, "concat_final", -1);
cur = concat_result;
}
cb(cur, "encoder_out", -1);
// QFormer projector
@@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() {
cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0);
}
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj);
ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj);
ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query,
model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b,
+2 -2
View File
@@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() {
}
// --- Stage 1b/1c: WindowQFormer blocks ---
const int projector_count = hparams.vision_feature_layer.size();
const int projector_count = hparams.feature_layers.size();
const float qformer_eps = 1e-12f;
ggml_tensor * mmproj = nullptr;
for (int bid = 0; bid < projector_count; ++bid) {
const auto & blk = model.qf_proj_blocks[bid];
int vlayer = hparams.vision_feature_layer[bid];
int vlayer = hparams.feature_layers[bid];
GGML_ASSERT(vlayer >= 0 && vlayer < n_layer);
ggml_tensor * h = layer_outs[vlayer];
+3 -3
View File
@@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If we set explicit vision feature layers, only go up to the deepest one
// NOTE: only used by granite-vision models for now
for (const auto & feature_layer : hparams.vision_feature_layer) {
for (const auto & feature_layer : hparams.feature_layers) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
@@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() {
// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (hparams.is_vision_feature_layer(il)) {
if (hparams.is_feature_layer(il)) {
embedding_stack.push_back(cur);
}
@@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() {
// process vision feature layers (used by granite)
{
// final layer is a vision feature layer
if (hparams.is_vision_feature_layer(max_feature_layer)) {
if (hparams.is_feature_layer(max_feature_layer)) {
embedding_stack.push_back(inpL);
}
+8 -1
View File
@@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() {
/* cb_eval */ nullptr,
/* cb_eval_user_data */ nullptr,
/* batch_max_tokens */ 1024,
/* progress_callback */ nullptr,
/* progress_callback_user_data */ nullptr,
};
return params;
}
@@ -345,6 +347,8 @@ struct mtmd_context {
/* cb_eval */ ctx_params.cb_eval,
/* cb_eval_user_data */ ctx_params.cb_eval_user_data,
/* no_alloc */ no_alloc,
/* progress_callback */ ctx_params.progress_callback,
/* progress_callback_user_data */ ctx_params.progress_callback_user_data,
};
auto res = clip_init(mmproj_fname, ctx_clip_params);
@@ -2133,9 +2137,12 @@ std::map<ggml_backend_dev_t, size_t> mtmd_get_memory_usage(const char * mmproj_f
mtmd::context_ptr ctx;
auto saved_log_callback = g_logger_state.log_callback;
auto saved_log_user_data = g_logger_state.log_callback_user_data;
ctx_params.progress_callback = nullptr;
try {
mtmd_log_set(stub_log_callback, nullptr); // suppress logging
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params));
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params, true));
mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback
std::map<ggml_backend_dev_t, size_t> total_mem;
auto merge = [&](const struct clip_ctx * c) {
+8
View File
@@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks;
typedef struct mtmd_input_text mtmd_input_text;
typedef struct mtmd_batch mtmd_batch;
typedef bool (*mtmd_progress_callback)(float progress, void * user_data);
struct mtmd_context_params {
bool use_gpu;
bool print_timings;
@@ -104,6 +106,12 @@ struct mtmd_context_params {
int32_t batch_max_tokens; // maximum number of output tokens in a batch
// (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit)
// (default: 1024)
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted.
mtmd_progress_callback progress_callback;
void * progress_callback_user_data;
};
MTMD_API const char * mtmd_default_marker(void);
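The callback contract described in the comment above (progress between 0.0 and 1.0, return `false` to abort) can be sketched as follows. This is a minimal standalone illustration, not the mtmd implementation: `progress_callback_t` repeats the typedef's signature so the sketch compiles alone, and `progress_state`, `record_progress` and `fake_load` are hypothetical names standing in for user code and for the tensor loading loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same signature as mtmd_progress_callback, repeated so this sketch compiles on its own.
typedef bool (*progress_callback_t)(float progress, void * user_data);

// Hypothetical user state: remembers every reported value and an optional cancel threshold.
struct progress_state {
    std::vector<float> seen;
    float cancel_at = 2.0f; // anything above 1.0 means "never cancel"
};

// Records the progress value; returns false (abort) once the threshold is reached.
static bool record_progress(float progress, void * user_data) {
    auto * st = static_cast<progress_state *>(user_data);
    st->seen.push_back(progress);
    return progress < st->cancel_at;
}

// Hypothetical stand-in for the loading loop: reports bytes loaded / total bytes
// after each tensor and stops as soon as the callback returns false.
static bool fake_load(const std::vector<size_t> & tensor_sizes, progress_callback_t cb, void * user_data) {
    size_t total = 0;
    for (size_t n : tensor_sizes) {
        total += n;
    }
    size_t loaded = 0;
    for (size_t n : tensor_sizes) {
        loaded += n;
        if (cb && total > 0) {
            if (!cb((float) loaded / (float) total, user_data)) {
                return false; // cancelled by the callback
            }
        }
    }
    return true;
}
```

A caller would pass `record_progress` together with a pointer to its own `progress_state`; the loader never inspects the user data, it only forwards it.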
+3 -3
View File
@@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will
The flow for downloading a new model:
- POST request comes in --> `post_router_models` --> validation
- `server_models::download()` is called
- Sets up a new thread `inst.th` and runs the download inside
- If a stop request comes in, set `stop_download` to `true`
- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD`
- The child process runs the download and reports status back to the router via stdin/stdout
- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process)
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
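
The stop mechanism in the flow above can be sketched from the child's side as a loop that reads commands from its standard input. This is only an illustration under stated assumptions: `kExitCommand` is a hypothetical placeholder (the actual value of `CMD_ROUTER_TO_CHILD_EXIT` is not shown here), and `wait_for_exit_command` is an invented helper name.

```cpp
#include <atomic>
#include <cassert>
#include <istream>
#include <sstream>
#include <string>

// Hypothetical command string; the real CMD_ROUTER_TO_CHILD_EXIT value may differ.
static const std::string kExitCommand = "exit";

// Reads commands line by line until the exit command or end of stream.
// Returns true if a stop was requested, false if the stream simply ended.
static bool wait_for_exit_command(std::istream & in, std::atomic<bool> & stop_flag) {
    std::string line;
    while (std::getline(in, line)) {
        if (line == kExitCommand) {
            stop_flag.store(true); // the download loop would poll this flag
            return true;
        }
        // any other line is ignored in this sketch
    }
    return false;
}
```

In a child process this would typically run on its own thread with `std::cin` as the input stream, while the download loop checks the flag between chunks.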
### Notable Related PRs
+42 -7
View File
@@ -1230,8 +1230,6 @@ print(completion.choices[0].text)
Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming modes are supported, so scripted and interactive applications work fine. While no strong claim of compatibility with the OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.
*Options:*
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
@@ -1250,9 +1248,18 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
For multimodal input:
- Content type `image_url` and `input_audio` are the same as OAI schema
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
For multimodal input (typed content, `messages[i].content[j]`):
- If `type == "image_url"`:
- `image_url.url` can be a remote URL, base64 (raw or URI-encoded via `data:image/...;base64`), or a path to a local file
- Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...)
- If `type == "input_audio"`:
- Either `input_audio.data` or `input_audio.url` can be specified; the value can be a remote URL, raw base64, or a path to a local file
- Accepts formats supported by `miniaudio` (mp3, wav, flac)
- `input_audio.format` is ignored; the file format is detected automatically
- If `type == "input_video"`:
- Either `input_video.data` or `input_video.url` can be specified; the value can be a remote URL, raw base64, or a path to a local file
- Accepts formats supported by `ffmpeg`
- Note: for local files, make sure to set `--media-path`. The file path must be prefixed with `file://`
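
The rules above amount to a dispatch on the prefix of the supplied value. The following is a rough sketch of that dispatch, not the server's actual code: `media_source`, `starts_with` and `classify_media_source` are hypothetical names, and the `accept_data_uri` flag is an assumption modelled on the note that only `image_url` accepts the `data:` form.

```cpp
#include <cassert>
#include <string>

// Hypothetical classification of a media value; names are illustrative only.
enum class media_source { remote_url, local_file, base64_data_uri, raw_base64 };

static bool starts_with(const std::string & s, const std::string & prefix) {
    return s.compare(0, prefix.size(), prefix) == 0;
}

// Decides how a value would be interpreted, based only on its prefix.
// accept_data_uri: whether the "data:...;base64," form is recognised for this content type.
static media_source classify_media_source(const std::string & value, bool accept_data_uri) {
    if (starts_with(value, "http://") || starts_with(value, "https://")) {
        return media_source::remote_url;
    }
    if (starts_with(value, "file://")) {
        return media_source::local_file; // requires --media-path on the server
    }
    if (accept_data_uri && starts_with(value, "data:")) {
        return media_source::base64_data_uri;
    }
    return media_source::raw_base64; // everything else is treated as raw base64
}
```

Anything that matches no known prefix falls through to raw base64, so a mistyped scheme would surface as a base64 decoding error rather than as a missing file.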
*Examples:*
@@ -1859,9 +1866,37 @@ Example events:
{
"model": "...",
"event": "download_finished",
"event": "model_status",
"data": {
"status": "loading"
"status": "loading",
"progress": {
"stages": ["text_model", "spec_model", "mmproj_model"],
"current": "text_model",
"value": 0.5
}
}
}
// note for "loading" status:
// - subsequent events will follow the same order of "stages" list
// - mmap may report incorrect progress on some platforms; if you need exact progress, use --no-mmap
{
"model": "...",
"event": "model_status",
"data": {
"status": "loaded",
"info": {
// note: only include info on first load
// waking up from sleep doesn't have this
}
}
}
{
"model": "...",
"event": "model_status",
"data": {
"status": "sleeping"
}
}
+46 -34
View File
@@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
return max_idx; // all tokens are equal
}
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
std::map<size_t, size_t> skips;
for (const auto & it : map_idx_to_media) {
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
}
return delims.split(tokens, skips);
}
bool server_tokens::validate(const struct llama_context * ctx) const {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -817,12 +825,21 @@ json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}
// media_path always ends with '/', see arg.cpp
// url can be
// - http(s):// for remote files
// - file:// for local files (only allowed if media_path is set)
// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...)
// - raw base64 encoded data
static void handle_media(
std::vector<raw_buffer> & out_files,
json & media_obj,
const std::string & media_path) {
std::string url = json_value(media_obj, "url", std::string());
const std::string & url,
const std::string & media_path,
bool accept_base64_uri) {
if (!media_path.empty()) {
// should already be enforced by arg.cpp, but checking just in case
GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR);
}
if (string_starts_with(url, "http")) {
// download remote image
// TODO @ngxson : maybe make these params configurable
@@ -858,20 +875,28 @@ static void handle_media(
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
out_files.push_back(data);
} else {
} else if (accept_base64_uri && string_starts_with(url, "data:")) {
// try to decode base64 image
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
if (parts.size() != 2) {
throw std::runtime_error("Invalid url value");
throw std::runtime_error("Invalid uri-encoded base64 value");
} else if (!string_starts_with(parts[0], "data:image/")) {
throw std::runtime_error("Invalid url format: " + parts[0]);
throw std::runtime_error("Invalid uri format: " + parts[0]);
} else if (!string_ends_with(parts[0], "base64")) {
throw std::runtime_error("url must be base64 encoded");
throw std::runtime_error("uri must be base64 encoded");
} else {
auto base64_data = parts[1];
auto decoded_data = base64_decode(base64_data);
out_files.push_back(decoded_data);
}
} else {
// try as raw base64 string
auto decoded_data = base64_decode(url);
if (decoded_data.empty()) {
throw std::runtime_error("Invalid base64 value");
}
out_files.push_back(decoded_data);
}
}
@@ -957,14 +982,15 @@ json oaicompat_chat_params_parse(
}
for (auto & p : content) {
std::string type = json_value(p, "type", std::string());
std::string type = json_value(p, "type", std::string());
if (type == "image_url") {
if (!opt.allow_image) {
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json image_url = json_value(p, "image_url", json::object());
handle_media(out_files, image_url, opt.media_path);
std::string url = json_value(image_url, "url", std::string());
handle_media(out_files, url, opt.media_path, true);
p["type"] = "media_marker";
p["text"] = get_media_marker();
@@ -975,17 +1001,11 @@ json oaicompat_chat_params_parse(
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_audio = json_value(p, "input_audio", json::object());
std::string data = json_value(input_audio, "data", std::string());
std::string format = json_value(input_audio, "format", std::string());
// while we also support flac, we don't allow it here so we matches the OAI spec
if (format != "wav" && format != "mp3") {
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
}
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
// TODO: add audio_url support by reusing handle_media()
// note: don't need to validate "format", it's redundant
json input_audio = json_value(p, "input_audio", json::object());
std::string url = json_value(input_audio, "data",
json_value(input_audio, "url", std::string()));
handle_media(out_files, url, opt.media_path, false);
p["type"] = "media_marker";
p["text"] = get_media_marker();
@@ -996,10 +1016,10 @@ json oaicompat_chat_params_parse(
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_video = json_value(p, "input_video", json::object());
std::string data = json_value(input_video, "data", std::string());
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
json input_video = json_value(p, "input_video", json::object());
std::string url = json_value(input_video, "data",
json_value(input_video, "url", std::string()));
handle_media(out_files, url, opt.media_path, false);
p["type"] = "media_marker";
p["text"] = get_media_marker();
@@ -1092,15 +1112,7 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}
llama_params["message_spans"] = json::array();
for (const auto & span : chat_params.message_spans) {
llama_params["message_spans"].push_back({
{ "role", span.role },
{ "pos", span.pos },
{ "len", span.len },
});
}
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
// Reasoning budget: pass parameters through to sampling layer
{
+3
View File
@@ -218,6 +218,9 @@ public:
size_t get_common_prefix(const server_tokens & b) const;
// split the tokens into message spans, skipping over media chunks
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
// make sure all text tokens are within the vocab range
bool validate(const struct llama_context * ctx) const;
File diff suppressed because it is too large Load Diff
+3 -1
View File
@@ -53,7 +53,7 @@ struct server_context_meta {
};
enum server_state {
// SERVER_STATE_DOWNLOADING,
SERVER_STATE_DOWNLOADING,
SERVER_STATE_LOADING,
SERVER_STATE_READY,
SERVER_STATE_SLEEPING,
@@ -61,6 +61,7 @@ enum server_state {
static std::string server_state_to_str(server_state state) {
switch (state) {
case SERVER_STATE_DOWNLOADING: return "downloading";
case SERVER_STATE_LOADING: return "loading";
case SERVER_STATE_READY: return "ready";
case SERVER_STATE_SLEEPING: return "sleeping";
@@ -69,6 +70,7 @@ static std::string server_state_to_str(server_state state) {
}
static server_state server_state_from_str(const std::string & str) {
if (str == "downloading") return SERVER_STATE_DOWNLOADING;
if (str == "loading") return SERVER_STATE_LOADING;
if (str == "ready") return SERVER_STATE_READY;
if (str == "sleeping") return SERVER_STATE_SLEEPING;
+19 -3
View File
@@ -7,9 +7,18 @@
#include <unordered_set>
#include <list>
#include <map>
#include <algorithm>
#include <cctype>
#include "server-http.h"
static std::string proxy_header_to_lower(std::string header) {
std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) {
return std::tolower(c);
});
return header;
}
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
std::string target_url = req.get_param("url");
common_http_url parsed_url = common_http_parse_url(target_url);
@@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
std::map<std::string, std::string> headers;
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
for (auto [key, value] : req.headers) {
auto new_key = key;
if (string_starts_with(new_key, "x-proxy-header-")) {
string_replace_all(new_key, "x-proxy-header-", "");
const std::string lowered_key = proxy_header_to_lower(key);
if (!string_starts_with(lowered_key, proxy_header_prefix)) {
continue;
}
auto new_key = key.substr(proxy_header_prefix.size());
if (new_key.empty()) {
continue;
}
headers[new_key] = value;
}
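The header filtering in the hunk above (case-insensitive prefix match, prefix stripped, empty names skipped) can be exercised in isolation with a small sketch. The prefix string is taken from the diff; `to_lower` and `extract_proxy_headers` are hypothetical helper names, and plain `std::map` stands in for the request's header container.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <map>
#include <string>

// Lower-cases an ASCII header name.
static std::string to_lower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
        return (char) std::tolower(c);
    });
    return s;
}

// Keeps only headers whose name starts with the proxy prefix (case-insensitive)
// and forwards them with the prefix removed; all other headers are dropped.
static std::map<std::string, std::string> extract_proxy_headers(const std::map<std::string, std::string> & in) {
    static const std::string prefix = "x-llama-server-proxy-header-";
    std::map<std::string, std::string> out;
    for (const auto & kv : in) {
        const std::string lowered = to_lower(kv.first);
        if (lowered.compare(0, prefix.size(), prefix) != 0) {
            continue; // not a proxy header
        }
        // the name after the prefix keeps its original casing
        std::string name = kv.first.substr(prefix.size());
        if (name.empty()) {
            continue; // prefix only, nothing to forward
        }
        out[name] = kv.second;
    }
    return out;
}
```

Compared with the removed code, which copied every header and only renamed the prefixed ones, this allow-list style forwards nothing unless the client opted in explicitly.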
+253 -203
View File
@@ -5,6 +5,7 @@
#include "build-info.h"
#include "preset.h"
#include "download.h"
#include "http.h"
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>
@@ -25,14 +26,7 @@
#include <sstream>
#include <cstring>
#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#ifndef _WIN32
extern char **environ;
#endif
@@ -64,6 +58,17 @@ struct server_subproc {
return sproc.has_value() && subprocess_alive(&sproc.value());
}
void request_exit() {
if (sproc.has_value()) {
FILE * stdin_file = subprocess_stdin(&sproc.value());
if (stdin_file) {
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
}
}
stopped.store(true, std::memory_order_relaxed);
}
void terminate() {
if (!sproc.has_value()) {
return;
@@ -213,7 +218,7 @@ void server_model_meta::update_caps() {
});
params.offline = true;
// params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
if (params.mmproj.path.empty()) {
multimodal = { false, false };
} else {
@@ -323,7 +328,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo
}
void server_models::load_models() {
// Phase 1: load presets from all sources pure I/O, no lock needed
// Phase 1: load presets from all sources - pure I/O, no lock needed
// 1. cached models
common_presets cached_models = ctx_preset.load_from_cache();
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
@@ -376,7 +381,7 @@ void server_models::load_models() {
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
};
// Helpers that read `mapping` must be called while holding the lock.
// Helpers that read `mapping` - must be called while holding the lock.
std::unordered_set<std::string> custom_names;
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
auto join_set = [](const std::set<std::string> & s) {
@@ -442,6 +447,7 @@ void server_models::load_models() {
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* progress */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
@@ -522,7 +528,7 @@ void server_models::load_models() {
}
}
// join outside the lock monitoring thread calls update_status (needs lock)
// join outside the lock - monitoring thread calls update_status (needs lock)
lk.unlock();
for (auto & th : threads_to_join) th.join();
lk.lock();
@@ -608,6 +614,7 @@ void server_models::load_models() {
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* progress */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
@@ -620,7 +627,7 @@ void server_models::load_models() {
apply_stop_timeout();
// clear reload flag before unlocking for autoload load() blocks on !is_reloading,
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
is_reloading = false;
cv.notify_all();
@@ -691,66 +698,6 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
return std::nullopt;
}
static int get_free_port() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
return -1;
}
typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
struct sockaddr_in serv_addr;
std::memset(&serv_addr, 0, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
serv_addr.sin_port = htons(0);
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
#ifdef _WIN32
int namelen = sizeof(serv_addr);
#else
socklen_t namelen = sizeof(serv_addr);
#endif
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return -1;
}
int port = ntohs(serv_addr.sin_port);
CLOSE_SOCKET(sock);
#ifdef _WIN32
WSACleanup();
#endif
return port;
}
// helper to convert vector<string> to char **
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
@@ -813,17 +760,23 @@ void server_models::unload_lru() {
}
void server_models::load(const std::string & name) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
load(name, load_options{});
}
void server_models::load(const std::string & name, const load_options & opts) {
if (!opts.custom_meta.has_value()) {
if (!has_model(name)) {
throw std::runtime_error("model name=" + name + " is not found");
}
unload_lru();
}
unload_lru();
std::unique_lock<std::mutex> lk(mutex);
// edge case: block until any in-progress reload has finished so we always load
// against the freshest preset and a consistent mapping state
cv.wait(lk, [this]() { return !is_reloading; });
auto meta = mapping[name].meta;
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
SRV_INF("model %s is not ready\n", name.c_str());
return;
@@ -848,7 +801,7 @@ void server_models::load(const std::string & name) {
// prepare new instance info
instance_t inst;
inst.meta = meta;
inst.meta.port = get_free_port();
inst.meta.port = common_http_get_free_port();
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
inst.meta.loaded_info = json{};
inst.meta.last_used = ggml_time_ms();
@@ -867,6 +820,12 @@ void server_models::load(const std::string & name) {
std::vector<std::string> child_env = base_env; // copy
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
}
SRV_INF("%s", "spawning server instance with args:\n");
for (const auto & arg : child_args) {
SRV_INF(" %s\n", arg.c_str());
@@ -884,13 +843,17 @@ void server_models::load(const std::string & name) {
if (result != 0) {
throw std::runtime_error("failed to spawn server instance");
}
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
}
// start a thread to manage the child process
// captured variables are guaranteed to be destroyed only after the thread is joined
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
inst.th = std::thread([
this, name,
child_proc = inst.subproc,
port = inst.meta.port,
stop_timeout = inst.meta.stop_timeout,
child_mode = opts.mode
]() {
FILE * stdin_file = subprocess_stdin(&child_proc->get());
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
@@ -923,7 +886,7 @@ void server_models::load(const std::string & name) {
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
});
}
// child crashed or finished on its own skip graceful shutdown sequence
// child crashed or finished on its own, skip graceful shutdown sequence
if (child_proc->stopped.load(std::memory_order_acquire)) {
return;
}
@@ -971,10 +934,14 @@ void server_models::load(const std::string & name) {
subprocess_destroy(&child_proc->get());
// update status and exit code
this->update_status(name, {
SERVER_MODEL_STATUS_UNLOADED,
exit_code
});
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
// instance will be cleaned up on next load_models() call
} else {
this->update_status(name, {
SERVER_MODEL_STATUS_UNLOADED,
exit_code
});
}
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
});
@@ -982,7 +949,7 @@ void server_models::load(const std::string & name) {
{
auto & old_instance = mapping[name];
// old process should have exited already, but just in case, we clean it up here
if (old_instance.subproc->is_alive()) {
if (old_instance.subproc && old_instance.subproc->is_alive()) {
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
old_instance.subproc->terminate(); // force kill
}
@@ -999,92 +966,13 @@ void server_models::load(const std::string & name) {
cv.notify_all();
}
// callback for model downloading functionality
struct server_models_download_res : public common_download_callback {
common_params_model model;
common_download_opts opts;
std::function<bool()> should_stop;
std::function<void(const common_download_progress & p)> on_progress;
bool is_ok = false;
bool run() {
try {
common_download_model(model, opts);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = model.get_name();
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
is_ok = false;
}
return is_ok;
}
void on_start(const common_download_progress & p) override {
on_progress(p);
}
void on_update(const common_download_progress & p) override {
on_progress(p);
}
void on_done(const common_download_progress &, bool ok) override {
is_ok = ok;
}
bool is_cancelled() const override {
return should_stop();
}
};
void server_models::download(common_params_model && model, common_download_opts && opts) {
std::string name = model.get_name();
GGML_ASSERT(name == model.hf_repo);
std::unique_lock<std::mutex> lk(mutex);
if (mapping.find(name) != mapping.end()) {
throw std::runtime_error("model name=" + name + " already exists");
}
instance_t inst;
inst.meta.name = name;
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
inst.subproc = std::make_shared<server_subproc>();
auto dl = std::make_unique<server_models_download_res>();
dl->model = model; // copy
dl->opts = opts; // copy
dl->should_stop = [sp = inst.subproc]() {
return sp->stopped.load(std::memory_order_relaxed);
};
dl->on_progress = [this, name](const common_download_progress & p) {
update_download_progress(name, p, false);
};
inst.th = std::thread([this, dl = std::move(dl)]() {
dl->opts.callback = dl.get();
bool ok = dl->run();
auto model_name = dl->model.get_name();
SRV_INF("download finished for model name=%s with status=%s\n",
model_name.c_str(), ok ? "success" : "failure");
update_download_progress(model_name, {}, true, ok);
// need_reload is set inside update_download_progress under the mutex;
// the next load_models() call will clean up this instance
});
mapping[name] = std::move(inst);
notify_sse("status_update", name, {
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
});
cv.notify_all();
}
void server_models::unload(const std::string & name) {
std::unique_lock<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
it->second.subproc->request_exit();
// for convenience, we wait for the status change here
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
@@ -1140,6 +1028,9 @@ void server_models::update_status(const std::string & name, const update_status_
if (!args.loaded_info.is_null()) {
meta.loaded_info = args.loaded_info;
}
if (!args.progress.is_null()) {
meta.progress = args.progress;
}
}
// broadcast status change to SSE
{
@@ -1152,6 +1043,9 @@ void server_models::update_status(const std::string & name, const update_status_
if (!args.loaded_info.is_null()) {
data["info"] = args.loaded_info;
}
if (!args.progress.is_null()) {
data["progress"] = args.progress;
}
// note: notify_sse doesn't acquire the lock, so no deadlock here
notify_sse("status_change", name, data);
}
@@ -1190,37 +1084,65 @@ void server_models::update_download_progress(const std::string & name, const com
}
bool server_models::remove(const std::string & name) {
auto meta = get_meta(name);
// do everything under one lock acquisition; avoid get_meta() /
// unload() because they can trigger load_models() which erases
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
std::unique_lock<std::mutex> lk(mutex);
if (!meta.has_value()) {
auto it = mapping.find(name);
if (it == mapping.end()) {
throw std::runtime_error("model name=" + name + " is not found");
}
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
}
unload(name); // cancel download or stop running instance
{
std::unique_lock<std::mutex> lk(mutex);
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
wait(lk, name, [](const server_model_meta & new_meta) {
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
});
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
// longer acquires this mutex, so joining while holding it is safe
if (mapping[name].th.joinable()) {
mapping[name].th.join();
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
// cancel in-flight download
SRV_INF("cancelling download for model name=%s\n", name.c_str());
it->second.subproc->request_exit();
} else if (it->second.meta.is_running()) {
// stop running instance
SRV_INF("stopping model instance name=%s\n", name.c_str());
stopping_models.insert(name);
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
it->second.subproc->terminate();
}
// remove the model from disk (hold lock to prevent concurrent load)
bool ok = common_download_remove(name);
if (ok) {
mapping.erase(name);
}
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
notify_sse("model_remove", name, {});
return ok;
cv_stop.notify_all();
}
// wait until the monitoring thread finishes
wait(lk, name, [](const server_model_meta & meta) {
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
});
// re-find after wait - load_models() may have erased the entry during the wait
it = mapping.find(name);
if (it == mapping.end()) {
// load_models() already joined the thread and erased the entry;
// we just need to clean up the cached files on disk
lk.unlock();
bool ok = common_download_remove(name);
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
notify_sse("model_remove", name, {});
return true;
}
// join before erasing - thread no longer acquires this mutex
if (it->second.th.joinable()) {
it->second.th.join();
}
// remove from disk (best-effort: cancelled downloads may have no cached files)
bool ok = common_download_remove(name);
mapping.erase(name);
if (!ok) {
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
}
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
notify_sse("model_remove", name, {});
return true;
}
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
@@ -1235,7 +1157,9 @@ void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string &
return predicate(it->second.meta);
}
return false;
// model was removed from mapping by another code path (e.g. load_models()).
// nothing left to wait for - tell the caller to proceed.
return true;
});
}
@@ -1320,10 +1244,39 @@ void server_models::handle_child_state(const std::string & name, const std::stri
}
switch (state) {
case SERVER_STATE_DOWNLOADING:
{
std::string result = json_value(payload, "result", std::string());
std::string url = json_value(payload, "url", std::string());
auto request_exit = [&]() {
std::lock_guard<std::mutex> lk(mutex);
auto it = mapping.find(name);
if (it != mapping.end()) {
return it->second.subproc->request_exit();
}
};
if (result == "download_finished") {
update_download_progress(name, {}, true, true);
request_exit();
} else if (result == "download_failed") {
update_download_progress(name, {}, true, false);
request_exit();
} else if (!url.empty()) {
common_download_progress p;
p.url = url;
p.downloaded = json_value(payload, "downloaded", (size_t)0);
p.total = json_value(payload, "total", (size_t)0);
update_download_progress(name, p, false);
}
} break;
case SERVER_STATE_LOADING:
{
// do nothing for now
// TODO: report loading progress for first load and wakeup from sleep
update_status(name, {
SERVER_MODEL_STATUS_LOADING,
0,
nullptr, // no loaded_info yet
payload,
});
} break;
case SERVER_STATE_READY:
{
@@ -1331,7 +1284,8 @@ void server_models::handle_child_state(const std::string & name, const std::stri
SERVER_MODEL_STATUS_LOADED,
0,
// note: payload can be empty if this is a wakeup from sleep
payload.size() > 0 ? payload : nullptr
payload.size() > 0 ? payload : nullptr,
{}, // reset progress info
});
} break;
case SERVER_STATE_SLEEPING:
@@ -1353,6 +1307,92 @@ bool server_child::is_child() {
return router_port != nullptr;
}
server_child_mode server_child::get_mode() {
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
std::string mode_str(mode ? mode : "");
if (mode_str == "download") {
return SERVER_CHILD_MODE_DOWNLOAD;
} else {
return SERVER_CHILD_MODE_NORMAL;
}
}
struct server_download_state : public common_download_callback {
server_child * self;
std::function<bool()> should_stop;
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
bool is_ok = false;
server_download_state(server_child * s) : self(s) {}
bool run(common_params & params) {
try {
common_params_handle_models_params p;
p.callback = this;
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p);
is_ok = true;
} catch (const std::exception & e) {
auto model_name = params.model.get_name();
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
is_ok = false;
}
return is_ok;
}
void on_progress(const common_download_progress & p) {
json data = {
{"url", p.url},
{"downloaded", p.downloaded},
{"total", p.total},
};
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
}
void on_start(const common_download_progress & p) override {
on_progress(p);
}
void on_update(const common_download_progress & p) override {
int64_t now = ggml_time_ms();
// throttle progress updates to avoid flooding logs
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
on_progress(p);
last_progress_time.store(now, std::memory_order_relaxed);
}
}
void on_done(const common_download_progress & p, bool) override {
on_progress(p);
}
bool is_cancelled() const override {
return should_stop ? should_stop() : false;
}
};
int server_child::run_download(common_params & params) {
auto cancelled = std::make_shared<std::atomic<bool>>(false);
// monitor stdin for cancellation command from the router
std::thread signal_thread = setup([cancelled](int) {
cancelled->store(true, std::memory_order_relaxed);
});
server_download_state dl(this);
dl.should_stop = [cancelled]() {
return cancelled->load(std::memory_order_relaxed);
};
bool ok = dl.run(params);
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
{"result", ok ? "download_finished" : "download_failed"},
});
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
if (signal_thread.joinable()) {
signal_thread.join();
}
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
return 0;
}
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
// setup thread for monitoring stdin
return std::thread([shutdown_handler]() {
@@ -1384,6 +1424,7 @@ void server_child::notify_to_router(const std::string & state, const json & payl
{"state", state},
{"payload", payload},
};
std::lock_guard<std::mutex> lk(mtx_stdout);
common_log_pause(common_log_main());
fflush(stdout);
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());
@@ -1625,7 +1666,7 @@ void server_models_routes::init_routes() {
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
return res;
}
if (!model->is_running()) {
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
return res;
}
@@ -1666,8 +1707,9 @@ void server_models_routes::init_routes() {
model.hf_repo = name;
opts.bearer_token = params.hf_token;
opts.download_mmproj = true;
opts.download_mtp = true;
// note: we only check the main model; sidecar files are not needed here
opts.download_mmproj = false;
opts.download_mtp = false;
// first, only check if the model is valid and can be downloaded
opts.skip_download = true;
@@ -1688,10 +1730,21 @@ void server_models_routes::init_routes() {
throw std::invalid_argument("model validation failed, unable to download");
}
// reject if model already exists
if (models.has_model(name)) {
throw std::invalid_argument("model '" + name + "' already exists");
}
// then, proceed with the actual download
opts.skip_download = false;
SRV_INF("starting download for model '%s'\n", name.c_str());
models.download(std::move(model), std::move(opts));
{
server_models::load_options load_opts;
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
load_opts.custom_meta = server_model_meta{};
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
load_opts.custom_meta->name = name;
models.load(name, load_opts);
}
res_ok(res, {{"success", true}});
return res;
@@ -1705,10 +1758,7 @@ void server_models_routes::init_routes() {
throw std::invalid_argument("model must be a non-empty string");
}
bool ok = models.remove(name);
if (!ok) {
throw std::runtime_error("failed to remove model '" + name + "'");
}
models.remove(name); // throws on error
res_ok(res, {{"success", true}});
return res;
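The `on_update` throttle in `server_download_state` above reads and writes `last_progress_time` in two separate steps, so two download threads that arrive in the same window could both pass the check. If that matters, one possible variation is to claim the window with a compare-and-exchange. The sketch below is an assumption-laden illustration: `progress_throttle` and `should_emit` are hypothetical names that do not appear in the diff, and the caller supplies the timestamp (for example from `ggml_time_ms()`).

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of the diff: lets at most one caller emit
// an update per `interval_ms` window, even when called from several threads.
struct progress_throttle {
    std::atomic<int64_t> last_ms{0};
    int64_t interval_ms;

    explicit progress_throttle(int64_t interval) : interval_ms(interval) {}

    // Returns true if the caller should emit an update at time `now_ms`.
    bool should_emit(int64_t now_ms) {
        int64_t prev = last_ms.load(std::memory_order_relaxed);
        while (now_ms - prev >= interval_ms) {
            // Only one thread wins the exchange for a given window. On failure,
            // `prev` is refreshed with the current value and the check repeats.
            if (last_ms.compare_exchange_weak(prev, now_ms, std::memory_order_relaxed)) {
                return true;
            }
        }
        return false;
    }
};
```

With a 100 ms interval, a call at t=1000 emits, a call at t=1050 is suppressed, and a call at t=1100 emits again. The trade-off is a retry loop in place of a plain load and store, which is cheap here because contention is limited to the handful of concurrent file downloads.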
+22 -6
@@ -40,6 +40,11 @@ enum server_model_source {
SERVER_MODEL_SOURCE_CACHE,
};
enum server_child_mode {
SERVER_CHILD_MODE_NORMAL, // load the model and run normally
SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit
};
static std::string server_model_status_to_string(server_model_status status) {
switch (status) {
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
@@ -72,6 +77,7 @@ struct server_model_meta {
int64_t last_used = 0; // for LRU unloading
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
json progress; // reflect load or download progress info, if any
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
mtmd_caps multimodal; // multimodal capabilities
@@ -104,7 +110,6 @@ private:
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
std::thread th;
server_model_meta meta;
FILE * stdin_file = nullptr;
};
std::mutex mutex;
@@ -160,22 +165,27 @@ public:
// return a copy of all model metadata (thread-safe)
std::vector<server_model_meta> get_all_meta();
struct load_options {
server_child_mode mode = SERVER_CHILD_MODE_NORMAL;
// used for spawning a downloading child process
std::optional<server_model_meta> custom_meta = std::nullopt;
};
// load and unload model instances
// these functions are thread-safe
void load(const std::string & name);
void load(const std::string & name, const load_options & opts);
void unload(const std::string & name);
void unload_all();
// download a new model, progress is reported via SSE
// to stop the download, call unload()
void download(common_params_model && model, common_download_opts && opts);
// update the status of a model instance (thread-safe)
struct update_status_args {
server_model_status status;
int exit_code = 0; // only valid if status == UNLOADED
json loaded_info = nullptr;
json progress = nullptr;
};
// update the status of a model instance (thread-safe)
// also send SSE notification to /models/sse endpoint
void update_status(const std::string & name, const update_status_args & args);
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
@@ -208,8 +218,14 @@ public:
};
struct server_child {
// serializes the notify_to_router writes
std::mutex mtx_stdout;
std::atomic<bool> is_finished_downloading = false; // set by run_download
// return true if the current process is a child server instance
bool is_child();
server_child_mode get_mode();
int run_download(common_params & params);
// register the shutdown_handler to be called by the router
// return the monitoring thread (to be joined by the caller)
+3
@@ -14,6 +14,9 @@ std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params &
fields.emplace_back(f);
};
add((new field_bool("verbose", params.verbose))
->set_desc("Include __verbose field in the response with additional debug information"));
add((new field_bool("timings_per_token", params.timings_per_token))
->set_desc("Include prompt processing and text generation speed information in each response"));
+6 -3
@@ -591,10 +591,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
output.push_back(json {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "fc_" + tool_call.id},
{"call_id", "call_" + tool_call.id},
{"name", tool_call.name},
});
}
@@ -690,10 +691,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
const json output_item = {
{"id", "fc_" + tool_call.id},
{"type", "function_call"},
{"status", "completed"},
{"arguments", tool_call.arguments},
{"call_id", "fc_" + tool_call.id},
{"call_id", "call_" + tool_call.id},
{"name", tool_call.name}
};
server_sent_events.push_back(json {
@@ -1277,8 +1279,9 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
{"data", json {
{"type", "response.output_item.added"},
{"item", json {
{"id", "fc_" + diff.tool_call_delta.id},
{"arguments", ""},
{"call_id", "fc_" + diff.tool_call_delta.id},
{"call_id", "call_" + diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name},
{"type", "function_call"},
{"status", "in_progress"},

Some files were not shown because too many files have changed in this diff.