mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
68 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a432e6f863 | |||
| 5d67f69f59 | |||
| beef5cf077 | |||
| b093e46873 | |||
| 1401fc3ca7 | |||
| 85c58bbcd0 | |||
| 19296c1735 | |||
| 90c111bf98 | |||
| 75ad0b23ed | |||
| f7421eabe8 | |||
| 59797670dc | |||
| c926ad0985 | |||
| a3900a6694 | |||
| 7c908502ea | |||
| 035cd8f9a6 | |||
| 73618f27a8 | |||
| 23ee8797e1 | |||
| dec5ca5577 | |||
| 9c0ac887f3 | |||
| 721354fbdf | |||
| 6ee0f65793 | |||
| 099b579acb | |||
| f8cc15f163 | |||
| 37957e8531 | |||
| d0f9d2e5ac | |||
| 0ef6f06d55 | |||
| 52b3df0023 | |||
| 7c082bc417 | |||
| bddfd2b113 | |||
| 0d135df48c | |||
| bf533823cd | |||
| 2f89acc2bc | |||
| bfa3219177 | |||
| d6d899580d | |||
| 8a118ee86c | |||
| d789527482 | |||
| 063d9c156e | |||
| c57607016a | |||
| 4a80943174 | |||
| 84de01a1f1 | |||
| 75f460ac28 | |||
| 8452824611 | |||
| e27f308597 | |||
| 67e9fd3b74 | |||
| 796f41bedc | |||
| 37a77fb057 | |||
| f4043fec01 | |||
| f449e05537 | |||
| 2b686a9120 | |||
| 4b48a53b6c | |||
| e475fa2b5f | |||
| 175147e8f6 | |||
| fabde3bf51 | |||
| 0d2d9ccbf6 | |||
| 8c2d6f6475 | |||
| 38724ab593 | |||
| e2e7a9b2d0 | |||
| b14e3fb90c | |||
| 159d093a43 | |||
| 5fd2dc2c41 | |||
| 1868af13ac | |||
| 5bd21b8555 | |||
| 80452d65b9 | |||
| 8141e730f1 | |||
| db52540f73 | |||
| 3a3edc9ac6 | |||
| 40f3aafc45 | |||
| a6b3260a42 |
@@ -13,6 +13,20 @@ ARG APP_REVISION=N/A
|
||||
# BUILD STAGE
|
||||
# Compile all binary files and libraries
|
||||
# ==============================================================================
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM ${CANN_BASE_IMAGE} AS build
|
||||
|
||||
# -- Install build dependencies --
|
||||
@@ -26,6 +40,8 @@ WORKDIR /app
|
||||
# -- Copy project files --
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
# -- Set CANN environment variables (required for compilation) --
|
||||
# Using ENV instead of `source` allows environment variables to persist across the entire image layer
|
||||
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
|
||||
|
||||
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
ARG TARGETARCH
|
||||
@@ -16,6 +30,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \
|
||||
else \
|
||||
|
||||
@@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
|
||||
ARG GCC_VERSION
|
||||
@@ -26,6 +40,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
|
||||
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
|
||||
fi && \
|
||||
|
||||
@@ -5,6 +5,20 @@ ARG APP_REVISION=N/A
|
||||
|
||||
## Build Image
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/intel/deep-learning-essentials:$ONEAPI_VERSION AS build
|
||||
|
||||
ARG GGML_SYCL_F16=ON
|
||||
@@ -22,6 +36,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
|
||||
echo "GGML_SYCL_F16 is set" \
|
||||
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON" \
|
||||
|
||||
@@ -10,6 +10,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
|
||||
|
||||
# MUSA architecture to build for (defaults to all supported archs)
|
||||
@@ -29,6 +43,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
|
||||
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
|
||||
fi && \
|
||||
|
||||
@@ -22,6 +22,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
## Build Image
|
||||
FROM docker.io/ubuntu:${UBUNTU_VERSION} AS build
|
||||
|
||||
@@ -69,6 +83,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
# Build Stage
|
||||
RUN bash -c "source ${OpenVINO_DIR}/setupvars.sh && \
|
||||
cmake -B build/ReleaseOV -G Ninja \
|
||||
|
||||
@@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
### Build image
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
@@ -38,6 +52,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
cmake -S . -B build \
|
||||
-DGGML_HIP=ON \
|
||||
|
||||
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Install build tools
|
||||
@@ -17,6 +31,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
|
||||
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
RUN apt-get update && \
|
||||
@@ -14,6 +28,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_ZENDNN=ON && \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
|
||||
build*/
|
||||
|
||||
tools/ui/node_modules/
|
||||
|
||||
models/*
|
||||
|
||||
/llama-cli
|
||||
|
||||
@@ -58,6 +58,13 @@ jobs:
|
||||
git tag ${{ steps.srctag.outputs.name }} || exit 0
|
||||
git push origin ${{ steps.srctag.outputs.name }} || exit 0
|
||||
|
||||
build_ui:
|
||||
name: Build UI
|
||||
needs: create_tag
|
||||
uses: ./.github/workflows/ui-build.yml
|
||||
with:
|
||||
hf_ui_version: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
prepare_matrices:
|
||||
name: Prepare Docker matrices
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -79,7 +86,7 @@ jobs:
|
||||
[
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
@@ -135,7 +142,7 @@ jobs:
|
||||
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Registry
|
||||
needs: [prepare_matrices, create_tag]
|
||||
needs: [prepare_matrices, create_tag, build_ui]
|
||||
|
||||
runs-on: ${{ matrix.config.runs_on }}
|
||||
strategy:
|
||||
@@ -150,6 +157,13 @@ jobs:
|
||||
fetch-depth: 0
|
||||
ref: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
- name: Download prebuilt UI
|
||||
if: ${{ matrix.config.prebuilt_ui == true }}
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
|
||||
with:
|
||||
name: ui-build
|
||||
path: tools/ui/dist
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
|
||||
|
||||
@@ -1627,6 +1627,7 @@ jobs:
|
||||
**Windows:**
|
||||
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
|
||||
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
|
||||
- [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip)
|
||||
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
|
||||
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip)
|
||||
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
|
||||
|
||||
@@ -25,13 +25,3 @@ Commits:
|
||||
- Do not explicitly set the git author in commits - rely on the default git config
|
||||
- Always use `--no-gpg-sign` when committing
|
||||
- Never `git push` without explicit confirmation from the user
|
||||
|
||||
Resources (read on demand):
|
||||
- [CONTRIBUTING.md](CONTRIBUTING.md)
|
||||
- [Build documentation](docs/build.md)
|
||||
- [Server usage documentation](tools/server/README.md)
|
||||
- [Server development documentation](tools/server/README-dev.md)
|
||||
- [PEG parser](docs/development/parsing.md)
|
||||
- [Auto parser](docs/autoparser.md)
|
||||
- [Jinja engine](common/jinja/README.md)
|
||||
- [PR template](.github/pull_request_template.md)
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
+80
-63
@@ -17,6 +17,7 @@
|
||||
# define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <shellapi.h>
|
||||
#endif
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
@@ -300,9 +301,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
// TODO @ngxson : refactor this into a new common_model_download_context
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
model.name = model.docker_repo;
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// If -m was used with -hf, treat the model "path" as the hf_file to download
|
||||
if (model.hf_file.empty() && !model.path.empty()) {
|
||||
@@ -322,7 +324,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
throw std::runtime_error("failed to download model from Hugging Face");
|
||||
}
|
||||
|
||||
model.name = model.hf_repo;
|
||||
model.path = download_result.model_path;
|
||||
|
||||
if (!download_result.mmproj_path.empty()) {
|
||||
@@ -397,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -408,6 +409,11 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.preset_only = handle_params.preset_only;
|
||||
|
||||
if (handle_params.callback) {
|
||||
opts.callback = handle_params.callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
@@ -585,19 +591,20 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex);
|
||||
common_params_handle_models(params, ctx_arg.ex, {});
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
|
||||
if (!can_skip_model && params.model.path.empty()) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
}
|
||||
@@ -893,7 +900,44 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
struct utf8_argv {
|
||||
std::vector<std::string> buf;
|
||||
std::vector<char*> ptrs;
|
||||
};
|
||||
|
||||
static utf8_argv make_utf8_argv() {
|
||||
utf8_argv out;
|
||||
int wargc = 0;
|
||||
LPWSTR* wargv = CommandLineToArgvW(GetCommandLineW(), &wargc);
|
||||
if (!wargv) return out;
|
||||
|
||||
out.buf.reserve(wargc);
|
||||
for (int i = 0; i < wargc; ++i) {
|
||||
int n = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, wargv[i], -1, nullptr, 0, nullptr, nullptr);
|
||||
if (n <= 0) { out.buf.emplace_back(); continue; }
|
||||
auto& s = out.buf.emplace_back();
|
||||
s.resize(static_cast<size_t>(n - 1));
|
||||
(void)WideCharToMultiByte(CP_UTF8, 0, wargv[i], -1, s.data(), n, nullptr, nullptr);
|
||||
}
|
||||
LocalFree(wargv);
|
||||
|
||||
out.ptrs.reserve(out.buf.size() + 1);
|
||||
for (auto& s : out.buf) out.ptrs.push_back(s.data());
|
||||
out.ptrs.push_back(nullptr);
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
|
||||
#ifdef _WIN32
|
||||
auto utf8 = make_utf8_argv();
|
||||
// repair argv only when it matches the process command line
|
||||
if (static_cast<int>(utf8.buf.size()) == argc) {
|
||||
argv = utf8.ptrs.data();
|
||||
}
|
||||
#endif
|
||||
|
||||
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
|
||||
const common_params params_org = ctx_arg.params; // the example can modify the default params
|
||||
|
||||
@@ -1074,6 +1118,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.completion = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--server-base"}, "URL",
|
||||
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_base = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--verbose-prompt"},
|
||||
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
|
||||
@@ -2830,62 +2881,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.api_prefix = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
|
||||
// Deprecated: use --ui-config instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui-config"}, "JSON",
|
||||
"[DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = value;
|
||||
params.webui_config_json = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--ui-config"}, "JSON",
|
||||
{"--ui-config", "--webui-config"}, "JSON",
|
||||
"JSON that provides default UI settings (overrides UI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = value;
|
||||
params.webui_config_json = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG"));
|
||||
|
||||
// Deprecated: use --ui-config-file instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui-config-file"}, "PATH",
|
||||
"[DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = read_file(value);
|
||||
params.webui_config_json = params.ui_config_json;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--ui-config-file"}, "PATH",
|
||||
{"--ui-config-file", "--webui-config-file"}, "PATH",
|
||||
"JSON file that provides default UI settings (overrides UI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = read_file(value);
|
||||
params.webui_config_json = params.ui_config_json;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG_FILE"));
|
||||
|
||||
// Deprecated: use --ui-mcp-proxy instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui-mcp-proxy"},
|
||||
{"--no-webui-mcp-proxy"},
|
||||
"[DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy",
|
||||
[](common_params & params, bool value) {
|
||||
params.ui_mcp_proxy = value;
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--ui-mcp-proxy"},
|
||||
{"--no-ui-mcp-proxy"},
|
||||
{"--ui-mcp-proxy", "--webui-mcp-proxy"},
|
||||
{"--no-ui-mcp-proxy", "--no-webui-mcp-proxy"},
|
||||
"experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)",
|
||||
[](common_params & params, bool value) {
|
||||
params.ui_mcp_proxy = value;
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_MCP_PROXY"));
|
||||
add_opt(common_arg(
|
||||
@@ -2897,24 +2912,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
// Deprecated: use --ui/--no-ui instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui"},
|
||||
{"--no-webui"},
|
||||
"[DEPRECATED: use --ui/--no-ui] whether to enable the Web UI",
|
||||
{"-ag", "--agent"},
|
||||
{"-no-ag", "--no-agent"},
|
||||
"whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)",
|
||||
[](common_params & params, bool value) {
|
||||
params.ui = value;
|
||||
params.webui = value;
|
||||
if (value) {
|
||||
params.server_tools = {"all"};
|
||||
params.ui_mcp_proxy = true;
|
||||
} else {
|
||||
params.server_tools.clear();
|
||||
params.ui_mcp_proxy = false;
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI"));
|
||||
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_AGENT"));
|
||||
add_opt(common_arg(
|
||||
{"--ui"},
|
||||
{"--no-ui"},
|
||||
{"--ui", "--webui"},
|
||||
{"--no-ui", "--no-webui"},
|
||||
string_format("whether to enable the Web UI (default: %s)", params.ui ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.ui = value;
|
||||
params.webui = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI"));
|
||||
add_opt(common_arg(
|
||||
@@ -2945,7 +2962,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
|
||||
add_opt(common_arg(
|
||||
{"--api-key-file"}, "FNAME",
|
||||
"path to file containing API keys (default: none)",
|
||||
"path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
std::ifstream key_file(value);
|
||||
if (!key_file) {
|
||||
@@ -2953,7 +2970,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
std::string key;
|
||||
while (std::getline(key_file, key)) {
|
||||
if (!key.empty()) {
|
||||
if (!key.empty() && key[0] != '#') {
|
||||
params.api_keys.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
+10
-1
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -129,11 +130,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_params_handle_models_params {
|
||||
common_download_callback * callback = nullptr;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
};
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
const common_params_handle_models_params & handle_params);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(until_suffix) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
p.ac(p.tool_arg_string_value(until_suffix) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
|
||||
(p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)))));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
|
||||
+103
-53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
common_chat_msg_delimiters delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+65
-6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -219,7 +279,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
|
||||
+15
-1
@@ -1074,6 +1074,18 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
|
||||
return files;
|
||||
}
|
||||
|
||||
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode) {
|
||||
#ifdef _WIN32
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
|
||||
if (!wlen) { return std::ifstream(); }
|
||||
std::vector<wchar_t> wfname(wlen);
|
||||
(void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen);
|
||||
return std::ifstream(wfname.data(), mode);
|
||||
#else
|
||||
return std::ifstream(fname, mode);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
@@ -2034,7 +2046,7 @@ bool common_prompt_batch_decode(
|
||||
}
|
||||
|
||||
size_t common_prompt_checkpoint::size() const {
|
||||
return data_tgt.size() + data_dft.size();
|
||||
return data_tgt.size() + data_dft.size() + data_spec.size();
|
||||
}
|
||||
|
||||
bool common_prompt_checkpoint::empty() const {
|
||||
@@ -2049,6 +2061,7 @@ void common_prompt_checkpoint::clear() {
|
||||
|
||||
data_tgt.clear();
|
||||
data_dft.clear();
|
||||
data_spec.clear();
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::update_pos(
|
||||
@@ -2138,4 +2151,5 @@ void common_prompt_checkpoint::clear_tgt() {
|
||||
|
||||
void common_prompt_checkpoint::clear_dft() {
|
||||
data_dft.clear();
|
||||
data_spec.clear();
|
||||
}
|
||||
|
||||
+22
-9
@@ -295,7 +295,16 @@ struct common_params_model {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
|
||||
std::string get_name() {
|
||||
if (!hf_repo.empty()) {
|
||||
return hf_repo;
|
||||
}
|
||||
if (!docker_repo.empty()) {
|
||||
return docker_repo;
|
||||
}
|
||||
return path;
|
||||
}
|
||||
};
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
@@ -363,7 +372,7 @@ struct common_params_speculative {
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
|
||||
});
|
||||
|
||||
return needs_rs_seq ? draft.n_max : 0u;
|
||||
@@ -600,7 +609,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
@@ -622,14 +631,11 @@ struct common_params {
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// CLI params
|
||||
std::string server_base; // if set, connect to this server instead of starting a new one
|
||||
|
||||
// UI configs
|
||||
bool ui = true;
|
||||
|
||||
// Deprecated: use ui, ui_mcp_proxy, ui_config_json instead
|
||||
bool webui = ui;
|
||||
bool webui_mcp_proxy = false;
|
||||
std::string webui_config_json;
|
||||
|
||||
bool ui_mcp_proxy = false;
|
||||
std::string ui_config_json;
|
||||
|
||||
@@ -848,6 +854,9 @@ struct common_file_info {
|
||||
};
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
// fs open, also handle UTF8 on Windows
|
||||
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode);
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
@@ -1065,6 +1074,10 @@ struct common_prompt_checkpoint {
|
||||
std::vector<uint8_t> data_tgt;
|
||||
std::vector<uint8_t> data_dft;
|
||||
|
||||
// (optional) speculative-decoding implementation state stashed with the checkpoint
|
||||
// (e.g. eagle3's deferred-boundary g_embd row)
|
||||
std::vector<uint8_t> data_spec;
|
||||
|
||||
size_t size() const;
|
||||
|
||||
bool empty() const;
|
||||
|
||||
+3
-1
@@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool preset_only = opts.preset_only;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
@@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else {
|
||||
} else if (!preset_only) {
|
||||
// only add other files if we're NOT in preset-only mode (normal run, non-router)
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ struct common_download_opts {
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,16 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
struct common_http_url {
|
||||
std::string scheme;
|
||||
std::string user;
|
||||
@@ -97,3 +107,63 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||
}
|
||||
|
||||
static int common_http_get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
+89
-46
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
|
||||
const size_t expected_count = this_args.size();
|
||||
const size_t input_count = args.count();
|
||||
|
||||
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this_args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this_args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
value macro_statement::execute_impl(context & ctx) {
|
||||
if (!is_stmt<identifier>(this->name)) {
|
||||
throw std::runtime_error("Macro name must be an identifier");
|
||||
}
|
||||
std::string name = cast_stmt<identifier>(this->name)->val;
|
||||
|
||||
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
|
||||
size_t expected_count = this->args.size();
|
||||
size_t input_count = args.count();
|
||||
const func_handler func = [this, name](const func_args & args) -> value {
|
||||
context macro_ctx(args.ctx); // new scope for macro execution
|
||||
|
||||
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
context macro_ctx(ctx); // new scope for macro execution
|
||||
|
||||
// bind parameters
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this->args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this->args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
bind_parameters(name, this->args, args, macro_ctx);
|
||||
|
||||
// execute macro body
|
||||
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
|
||||
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
value call_statement::execute_impl(context & ctx) {
|
||||
auto call_expr = cast_stmt<call_expression>(this->call);
|
||||
if (!call_expr) {
|
||||
throw std::runtime_error("Call statement requires a valid call expression");
|
||||
}
|
||||
|
||||
value callee_val = call_expr->callee->execute(ctx);
|
||||
if (!is_val<value_func>(callee_val)) {
|
||||
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
|
||||
}
|
||||
auto * callee_func = cast_val<value_func>(callee_val);
|
||||
|
||||
context caller_ctx(ctx); // new scope for caller execution
|
||||
|
||||
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
|
||||
context block_ctx(caller_ctx); // new scope for block execution
|
||||
|
||||
bind_parameters("caller", this->caller_args, args, block_ctx);
|
||||
|
||||
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
|
||||
auto res = exec_statements(this->body, block_ctx);
|
||||
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
|
||||
return res;
|
||||
};
|
||||
|
||||
context call_ctx(ctx);
|
||||
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
|
||||
|
||||
func_args args(call_ctx);
|
||||
|
||||
for (const auto & arg_expr : call_expr->args) {
|
||||
auto arg_val = arg_expr->execute(ctx);
|
||||
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
|
||||
args.push_back(arg_val);
|
||||
}
|
||||
|
||||
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
|
||||
return callee_func->invoke(args);
|
||||
}
|
||||
|
||||
value member_expression::execute_impl(context & ctx) {
|
||||
value object = this->object->execute(ctx);
|
||||
|
||||
|
||||
@@ -552,6 +552,7 @@ struct call_statement : public statement {
|
||||
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
|
||||
@@ -233,27 +233,27 @@ struct BuiltinRule {
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
{"boolean", {"(\"true\" | \"false\") space", {}}},
|
||||
{"boolean", {"(\"true\" | \"false\")", {}}},
|
||||
{"decimal-part", {"[0-9]{1,16}", {}}},
|
||||
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part)", {"integral-part"}}},
|
||||
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}},
|
||||
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
|
||||
{"null", {"\"null\" space", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}},
|
||||
{"null", {"\"null\"", {}}},
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
|
||||
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
|
||||
{"date-time", {"date \"T\" time", {"date", "time"}}},
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}}
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
@@ -551,16 +551,16 @@ private:
|
||||
}
|
||||
return join_seq();
|
||||
};
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"");
|
||||
}
|
||||
|
||||
/*
|
||||
Returns a rule that matches a JSON string that is none of the provided strings
|
||||
|
||||
not_strings({"a"})
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
not_strings({"and", "also"})
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
|
||||
*/
|
||||
std::string _not_strings(const std::vector<std::string> & strings) {
|
||||
|
||||
@@ -619,7 +619,7 @@ private:
|
||||
if (!trie.is_end_of_string) {
|
||||
out << "?";
|
||||
}
|
||||
out << " [\"] space";
|
||||
out << " [\"]";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
@@ -725,7 +725,7 @@ private:
|
||||
rule += " )?";
|
||||
}
|
||||
|
||||
rule += " \"}\" space";
|
||||
rule += " space \"}\"";
|
||||
|
||||
return rule;
|
||||
}
|
||||
@@ -858,14 +858,14 @@ public:
|
||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||
}
|
||||
if (schema.contains("const")) {
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
||||
}
|
||||
if (schema.contains("enum")) {
|
||||
std::vector<std::string> enum_values;
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
@@ -933,7 +933,7 @@ public:
|
||||
}
|
||||
}
|
||||
if (!enum_intersection.empty()) {
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")");
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
@@ -948,7 +948,7 @@ public:
|
||||
}
|
||||
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
||||
}
|
||||
rule += " \"]\" space";
|
||||
rule += " space \"]\"";
|
||||
return _add_rule(rule_name, rule);
|
||||
}
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
@@ -956,7 +956,7 @@ public:
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\"");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
return _visit_pattern(schema["pattern"], rule_name);
|
||||
@@ -972,7 +972,7 @@ public:
|
||||
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
||||
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
||||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\"");
|
||||
}
|
||||
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
@@ -990,7 +990,7 @@ public:
|
||||
std::stringstream out;
|
||||
out << "(";
|
||||
build_min_max_int(min_value, max_value, out);
|
||||
out << ") space";
|
||||
out << ")";
|
||||
return _add_rule(rule_name, out.str());
|
||||
}
|
||||
if (schema.empty() || schema_type == "object") {
|
||||
|
||||
+202
-89
@@ -6,13 +6,14 @@
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <initializer_list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
||||
// Trick to catch missing branches
|
||||
template <typename T>
|
||||
@@ -88,40 +89,7 @@ struct trie {
|
||||
return match_result{match_result::NO_MATCH};
|
||||
}
|
||||
|
||||
struct prefix_and_next {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<uint32_t> next_chars;
|
||||
};
|
||||
|
||||
std::vector<prefix_and_next> collect_prefix_and_next() {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<prefix_and_next> result;
|
||||
collect_prefix_and_next(0, prefix, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
|
||||
if (!nodes[index].is_word) {
|
||||
if (!nodes[index].children.empty()) {
|
||||
std::vector<uint32_t> chars;
|
||||
chars.reserve(nodes[index].children.size());
|
||||
for (const auto & p : nodes[index].children) {
|
||||
chars.push_back(p.first);
|
||||
}
|
||||
out.emplace_back(prefix_and_next{prefix, chars});
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & p : nodes[index].children) {
|
||||
uint32_t ch = p.first;
|
||||
auto child = p.second;
|
||||
prefix.push_back(ch);
|
||||
collect_prefix_and_next(child, prefix, out);
|
||||
prefix.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
size_t create_node() {
|
||||
size_t index = nodes.size();
|
||||
nodes.emplace_back();
|
||||
@@ -153,6 +121,65 @@ struct trie {
|
||||
}
|
||||
};
|
||||
|
||||
// Aho-Corasick automaton
|
||||
struct aho_corasick {
|
||||
trie t;
|
||||
std::vector<size_t> fail; // failure links
|
||||
std::vector<size_t> order; // states in BFS order
|
||||
std::vector<bool> terminal; // match states (directly or via a suffix link)
|
||||
std::set<uint32_t> alphabet; // every character with a transition
|
||||
|
||||
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
|
||||
const auto & nodes = t.nodes;
|
||||
const size_t n = nodes.size();
|
||||
|
||||
fail.assign(n, 0);
|
||||
order.reserve(n);
|
||||
|
||||
std::deque<size_t> queue{ 0 };
|
||||
while (!queue.empty()) {
|
||||
size_t u = queue.front();
|
||||
queue.pop_front();
|
||||
order.push_back(u);
|
||||
for (const auto & [ch, v] : nodes[u].children) {
|
||||
if (u != 0) {
|
||||
size_t f = fail[u];
|
||||
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
|
||||
f = fail[f];
|
||||
}
|
||||
auto it = nodes[f].children.find(ch);
|
||||
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
|
||||
}
|
||||
queue.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
terminal.assign(n, false);
|
||||
for (size_t u : order) {
|
||||
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
|
||||
}
|
||||
|
||||
for (const auto & node : nodes) {
|
||||
for (const auto & [ch, v] : node.children) {
|
||||
alphabet.insert(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_states() const { return t.nodes.size(); }
|
||||
bool is_terminal(size_t s) const { return terminal[s]; }
|
||||
|
||||
// follow failure links until a transition on `ch` exists.
|
||||
size_t next(size_t state, uint32_t ch) const {
|
||||
const auto & nodes = t.nodes;
|
||||
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
|
||||
state = fail[state];
|
||||
}
|
||||
auto it = nodes[state].children.find(ch);
|
||||
return it != nodes[state].children.end() ? it->second : 0;
|
||||
}
|
||||
};
|
||||
|
||||
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
|
||||
if (pos + hex_count > str.length()) {
|
||||
return {0, 0};
|
||||
@@ -894,6 +921,10 @@ struct parser_executor {
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() {
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
||||
std::unordered_set<common_peg_parser_id> visited;
|
||||
std::set<common_peg_parser_id> visited;
|
||||
return dump_impl(id, visited);
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
|
||||
std::unordered_set<common_peg_parser_id> & visited) const {
|
||||
std::set<common_peg_parser_id> & visited) const {
|
||||
// Check for cycles
|
||||
if (visited.count(id)) {
|
||||
return "[cycle]";
|
||||
@@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() {
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
@@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
|
||||
if (delimiters.empty()) {
|
||||
throw std::runtime_error("ac parser requires at least one delimiter");
|
||||
}
|
||||
return add(common_peg_ac_parser{p, delimiters});
|
||||
}
|
||||
|
||||
static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
if (c == '-' || c == ']' || c == '[' || c == '\\') {
|
||||
return "\\" + std::string(1, (char) c);
|
||||
@@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
|
||||
trie matcher(strings);
|
||||
auto pieces = matcher.collect_prefix_and_next();
|
||||
|
||||
std::string pattern;
|
||||
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
|
||||
for (size_t i = 0; i < pieces.size(); ++i) {
|
||||
if (i > 0) {
|
||||
pattern += " | ";
|
||||
}
|
||||
|
||||
const auto & pre = pieces[i].prefix;
|
||||
const auto & chars = pieces[i].next_chars;
|
||||
|
||||
std::string cls;
|
||||
cls.reserve(chars.size());
|
||||
for (uint32_t ch : chars) {
|
||||
cls += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
if (!pre.empty()) {
|
||||
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
|
||||
pattern += pre_literal + " [^" + cls + "]";
|
||||
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
|
||||
// char, so the repetition alone cannot match a value that *ends* on a proper
|
||||
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
|
||||
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
|
||||
// values, so without this the grammar would reject input the parser accepts.
|
||||
// Allow the value to terminate on any proper prefix as an optional tail.
|
||||
// This makes the grammar a slight superset of the runtime language (a value
|
||||
// may end on the longest prefix, which greedy first-match would not itself
|
||||
// produce); harmless for constrained generation, which only needs to admit
|
||||
// every runtime-valid string.
|
||||
if (!trailing.empty()) {
|
||||
trailing += " | ";
|
||||
}
|
||||
trailing += pre_literal;
|
||||
} else {
|
||||
pattern += "[^" + cls + "]";
|
||||
}
|
||||
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
std::string result = "(" + pattern + ")*";
|
||||
if (!trailing.empty()) {
|
||||
result += " (" + trailing + ")?";
|
||||
}
|
||||
return result;
|
||||
return s + "]";
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> collect_reachable_rules(
|
||||
static std::string gbnf_ac_grammar(
|
||||
const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings,
|
||||
const std::function<std::string(const std::vector<uint32_t> &,
|
||||
const std::map<size_t, std::vector<uint32_t>> &,
|
||||
const std::vector<uint32_t> &,
|
||||
const std::function<std::string(size_t)> &)> & build_rule) {
|
||||
aho_corasick ac(strings);
|
||||
|
||||
auto state_name = [&](size_t s) -> std::string {
|
||||
if (s == 0) {
|
||||
return prefix;
|
||||
}
|
||||
std::string num = std::to_string(s);
|
||||
num = num.size() == 1 ? ("0" + num) : num;
|
||||
return prefix + "-" + num;
|
||||
};
|
||||
|
||||
for (size_t q = 0; q < ac.num_states(); q++) {
|
||||
if (ac.is_terminal(q)) {
|
||||
continue; // match states
|
||||
}
|
||||
|
||||
std::map<size_t, std::vector<uint32_t>> buckets;
|
||||
std::vector<uint32_t> completing; // chars that complete a delimiter
|
||||
std::vector<uint32_t> specific; // chars with an explicit transition
|
||||
for (uint32_t c : ac.alphabet) {
|
||||
size_t d = ac.next(q, c);
|
||||
if (ac.is_terminal(d)) {
|
||||
completing.push_back(c);
|
||||
specific.push_back(c);
|
||||
} else if (d != 0) {
|
||||
buckets[d].push_back(c); // specific non-root destination
|
||||
specific.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
|
||||
}
|
||||
|
||||
// An empty delimiter makes the start state terminal. Emit an entry rule
|
||||
// that matches the empty string so the returned reference stays valid.
|
||||
if (ac.is_terminal(0)) {
|
||||
builder.add_rule(prefix, "|");
|
||||
}
|
||||
|
||||
return state_name(0);
|
||||
}
|
||||
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & /*completing*/,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
// every state is accepting and completing chars get no
|
||||
// alternative, so a forbidden string can never be matched
|
||||
std::string rhs = "|";
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
|
||||
return rhs;
|
||||
});
|
||||
}
|
||||
|
||||
// GBNF grammar matching everything up to and including the first occurrence of
|
||||
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & completing,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
std::vector<std::string> alts;
|
||||
if (!completing.empty()) {
|
||||
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
|
||||
}
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
|
||||
}
|
||||
// every other character keeps scanning from the start state
|
||||
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
|
||||
return string_join(alts, " | ");
|
||||
});
|
||||
}
|
||||
|
||||
static std::set<std::string> collect_reachable_rules(
|
||||
const common_peg_arena & arena,
|
||||
const common_peg_parser_id & rule
|
||||
) {
|
||||
std::unordered_set<std::string> reachable;
|
||||
std::unordered_set<std::string> visited;
|
||||
std::set<std::string> reachable;
|
||||
std::set<std::string> visited;
|
||||
|
||||
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
|
||||
const auto & parser = arena.get(id);
|
||||
@@ -1588,6 +1686,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
}
|
||||
return gbnf_excluding_pattern(p.delimiters);
|
||||
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
if (schema_delegates(p)) {
|
||||
return to_gbnf(p.child);
|
||||
@@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
};
|
||||
|
||||
// Collect reachable rules
|
||||
std::unordered_set<std::string> reachable_rules;
|
||||
std::set<std::string> reachable_rules;
|
||||
|
||||
if (lazy) {
|
||||
// Collect rules reachable from trigger rules
|
||||
@@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "ac") {
|
||||
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
|
||||
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
|
||||
}
|
||||
return common_peg_ac_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["delimiters"].get<std::vector<std::string>>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
+16
-3
@@ -3,8 +3,8 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
struct common_peg_ac_parser {
|
||||
common_peg_parser_id child;
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
common_peg_gbnf_parser,
|
||||
common_peg_ac_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -335,7 +341,7 @@ class common_peg_arena {
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
|
||||
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
|
||||
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
|
||||
// automaton of `delimiters`, matching everything up to and including the
|
||||
// first delimiter. Parsing delegates entirely to the child, which is
|
||||
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
+174
-35
@@ -161,6 +161,10 @@ struct common_speculative_impl {
|
||||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
|
||||
|
||||
// (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
|
||||
virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const { return false; }
|
||||
virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {}
|
||||
|
||||
// true if this implementation requires the target context to extract post-norm embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
|
||||
@@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
(size_t) n_embd_dec * sizeof(float));
|
||||
}
|
||||
|
||||
// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
|
||||
// their single-position checkpoints drop it on restore
|
||||
bool need_boundary_stash() const {
|
||||
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
|
||||
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
|
||||
}
|
||||
|
||||
bool get_state(llama_seq_id seq_id, std::vector<uint8_t> & data) const override {
|
||||
if (!need_boundary_stash()) {
|
||||
return false;
|
||||
}
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_pos pos = pending_pos_last[seq_id];
|
||||
const std::vector<float> & g = pending_g_last[seq_id];
|
||||
|
||||
data.resize(sizeof(llama_pos) + g.size() * sizeof(float));
|
||||
std::memcpy(data.data(), &pos, sizeof(llama_pos));
|
||||
std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float));
|
||||
return true;
|
||||
}
|
||||
|
||||
void set_state(llama_seq_id seq_id, const std::vector<uint8_t> & data) override {
|
||||
if (!need_boundary_stash()) {
|
||||
return;
|
||||
}
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
return;
|
||||
}
|
||||
if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_pos pos = -1;
|
||||
std::memcpy(&pos, data.data(), sizeof(llama_pos));
|
||||
|
||||
pending_pos_last[seq_id] = pos;
|
||||
pending_g_last[seq_id].resize(n_embd_dec);
|
||||
std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float));
|
||||
}
|
||||
|
||||
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
@@ -858,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool is_mem_shared = false;
|
||||
// One MTP draft driver, three modes (set once in the ctor):
|
||||
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
|
||||
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
|
||||
// neither (qwen35 / qwen35moe): a single trained MTP head.
|
||||
int32_t n_mtp_layers = 1;
|
||||
bool is_mem_shared = false; // gemma4
|
||||
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
@@ -873,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
std::vector<std::vector<float>> verify_h;
|
||||
std::vector<int32_t> verify_h_rows;
|
||||
|
||||
// Per-seq draft length from the last draft() call, used in accept() to
|
||||
// roll back ctx_dft's recurrent state past the AR draft's redundant
|
||||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
std::vector<int> i_last;
|
||||
std::vector<std::vector<float>> chain_h;
|
||||
|
||||
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
@@ -889,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
|
||||
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
@@ -935,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
|
||||
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
|
||||
|
||||
if (chain_heads) {
|
||||
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
|
||||
|
||||
chain_h.assign(n_seq, {});
|
||||
for (auto & c : chain_h) {
|
||||
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
i_last.assign(n_seq, -1);
|
||||
i_batch_beg.assign(n_seq, -1);
|
||||
i_batch_end.assign(n_seq, -1);
|
||||
|
||||
verify_h.assign(n_seq, {});
|
||||
verify_h_rows.assign(n_seq, 0);
|
||||
|
||||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
~common_speculative_impl_draft_mtp() override {
|
||||
@@ -1050,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
bool ok = true;
|
||||
for (int head = 0; head < n_mtp_layers; ++head) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, head);
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1087,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
int n_drafting = 0;
|
||||
std::vector<bool> drafting(n_seq);
|
||||
|
||||
const float * h_row = nullptr;
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1102,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_sampler_reset(smpls[seq_id].get());
|
||||
|
||||
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
|
||||
|
||||
h_row = pending_h[seq_id].data();
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
}
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
if (chain_heads) {
|
||||
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
|
||||
}
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
|
||||
while (n_drafting > 0) {
|
||||
int i_batch = 0;
|
||||
// each step decodes under a different head, i.e. a different decoder layer, and
|
||||
// KV is per layer. process() filled this layer's KV only for positions < n_past
|
||||
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
|
||||
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
|
||||
// and select head i so it rebuilds its own layer's KV there; decoding just the
|
||||
// latest token would leave its attention reading cells only another head wrote.
|
||||
if (chain_heads) {
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (drafting[seq_id]) {
|
||||
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
|
||||
}
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, i);
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
// rebuild the batch for the next step: the growing-KV paths re-add only the
|
||||
// new token (the KV already holds the prefix), while chained heads re-add the
|
||||
// whole prefix at the next head. dropped sequences are simply not re-added.
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1127,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, i_batch, true);
|
||||
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
|
||||
++i_batch;
|
||||
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
|
||||
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
@@ -1163,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (is_mem_shared) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
|
||||
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
|
||||
|
||||
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
|
||||
for (int t = 0; t < n_rows; ++t) {
|
||||
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
|
||||
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
|
||||
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
|
||||
}
|
||||
} else if (is_mem_shared) {
|
||||
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
|
||||
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
|
||||
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
} else {
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
}
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
++i;
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
@@ -1196,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
if (dp.result->size() < (size_t) params.n_min) {
|
||||
dp.result->clear();
|
||||
}
|
||||
|
||||
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1810,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
|
||||
|
||||
@@ -1848,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_eagle3) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
if (has_draft_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
}
|
||||
@@ -2118,6 +2232,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: support the case of more than one speculative implementations having a state
|
||||
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data) {
|
||||
if (spec == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->get_state(seq_id, data)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
impl->set_state(seq_id, data);
|
||||
}
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
|
||||
@@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec);
|
||||
// informs the speculative context that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
|
||||
|
||||
// (optional) get/set internal state
|
||||
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data);
|
||||
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -261,6 +262,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
|
||||
@@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel):
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
|
||||
+7
-1
@@ -1119,8 +1119,10 @@ class TextModel(ModelBase):
|
||||
|
||||
rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True)
|
||||
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True)
|
||||
partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True)
|
||||
original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True)
|
||||
|
||||
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
|
||||
# Ensure global params are mirrored in rope_parameters
|
||||
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
|
||||
if local_rope_theta is not None:
|
||||
self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
|
||||
@@ -1128,6 +1130,10 @@ class TextModel(ModelBase):
|
||||
self.rope_parameters["rope_theta"] = rope_theta
|
||||
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
|
||||
self.rope_parameters["rope_type"] = rope_type
|
||||
if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None:
|
||||
self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
|
||||
if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None:
|
||||
self.rope_parameters["original_max_position_embeddings"] = original_max_position_embeddings
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
|
||||
@@ -148,7 +148,7 @@ class ChatGLMModel(TextModel):
|
||||
rope_dim = self.hparams["attention_dim"]
|
||||
else:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
rope_freq = 10000
|
||||
if "rope_ratio" in self.hparams:
|
||||
|
||||
+1
-1
@@ -161,7 +161,7 @@ class DeciModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
@@ -24,7 +24,7 @@ class ExaoneModel(TextModel):
|
||||
|
||||
assert (hparams["activation_function"] == "silu")
|
||||
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
|
||||
rotary_factor = self.rope_parameters.get("partial_rotary_factor")
|
||||
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
|
||||
@@ -39,7 +39,7 @@ class ExaoneModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
@@ -104,7 +104,7 @@ class Exaone4Model(TextModel):
|
||||
factor = rope_params.get("factor", 16.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model):
|
||||
self.gguf_writer.add_head_count_kv(value_arr)
|
||||
|
||||
# handle n_rot differently for global vs swa layers
|
||||
partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors
|
||||
n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
|
||||
self.gguf_writer.add_rope_dimension_count(n_rot_full)
|
||||
|
||||
+2
-2
@@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel):
|
||||
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
)
|
||||
self.gguf_writer.add_rope_dimension_count(
|
||||
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
|
||||
int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))
|
||||
)
|
||||
|
||||
# MoE parameters - Use only routed expert count (shared experts handled separately)
|
||||
@@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
rope_dim = self.hparams["qk_rope_head_dim"]
|
||||
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
|
||||
|
||||
# NextN/MTP prediction layers
|
||||
|
||||
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
|
||||
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
|
||||
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_audio is not None
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Add feature_layer if present in encoder config
|
||||
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
|
||||
self.gguf_writer.add_audio_feature_layers(feature_layers)
|
||||
logger.info(f"gguf: audio feature_layers = {feature_layers}")
|
||||
|
||||
# Validate projector dimension matches concatenated encoder output
|
||||
hidden_dim = self.hparams_audio["hidden_dim"]
|
||||
expected_dim = hidden_dim * (len(feature_layers) + 1)
|
||||
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
|
||||
|
||||
if projector_dim != expected_dim:
|
||||
raise ValueError(
|
||||
f"Projector encoder_hidden_size ({projector_dim}) does not match "
|
||||
f"expected concatenated dimension ({expected_dim}). "
|
||||
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
|
||||
)
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
+1
-1
@@ -289,7 +289,7 @@ class LlamaModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -154,7 +154,7 @@ class MimoV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
|
||||
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
|
||||
rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"])
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
|
||||
|
||||
+6
-10
@@ -32,11 +32,9 @@ class MiniCPMModel(TextModel):
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if long_factors or short_factors:
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
@@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel):
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if long_factors or short_factors:
|
||||
rope_dims = self.hparams["qk_rope_head_dim"]
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
|
||||
@@ -125,17 +125,18 @@ class NemotronModel(TextModel):
|
||||
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
||||
|
||||
# * Partial RoPE
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
|
||||
rot_pct = self.rope_parameters["partial_rotary_factor"]
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
|
||||
|
||||
# * RopeScaling for Nemotron
|
||||
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
|
||||
factor = self.hparams.get("factor") or self.rope_parameters.get("factor")
|
||||
if factor is None:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
else:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
|
||||
self.gguf_writer.add_rope_scaling_factor(factor)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
|
||||
|
||||
+9
-11
@@ -18,7 +18,7 @@ class Phi2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||
rot_pct = self.rope_parameters["partial_rotary_factor"]
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
|
||||
@@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel):
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
rms_eps = self.find_hparam(["rms_norm_eps"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
|
||||
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
rope_dims = int(rot_pct * n_embd) // n_head
|
||||
|
||||
self.gguf_writer.add_context_length(max_pos_embds)
|
||||
@@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel):
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
|
||||
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
rope_dims = int(rot_pct * n_embd) // n_head
|
||||
|
||||
# write rope scaling for long context (128k) model
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if not long_factors:
|
||||
return
|
||||
|
||||
scale = max_pos_embds / orig_max_pos_embds
|
||||
|
||||
rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower()
|
||||
rope_scaling_type = self.rope_parameters.get('rope_type', '').lower()
|
||||
if len(rope_scaling_type) == 0:
|
||||
raise KeyError('Missing the required key rope_scaling.type')
|
||||
|
||||
@@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel):
|
||||
|
||||
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
|
||||
+1
-1
@@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel):
|
||||
self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4))
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25)))
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
|
||||
@@ -28,7 +28,7 @@ class StableLMModel(TextModel):
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||
rotary_factor = self.rope_parameters["partial_rotary_factor"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
|
||||
+1
-1
@@ -314,7 +314,7 @@ class Step35Model(TextModel):
|
||||
factor = float(rope_params.get("factor", 8.0))
|
||||
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
|
||||
high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
|
||||
old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
|
||||
old_context_len = int(rope_params.get("original_max_position_embeddings", 8192))
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
|
||||
|
||||
```
|
||||
$ apt update && apt upgrade -y
|
||||
$ apt install git cmake
|
||||
$ apt install git cmake libandroid-spawn
|
||||
```
|
||||
|
||||
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
|
||||
|
||||
@@ -198,18 +198,18 @@ class BuiltinRule:
|
||||
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
'boolean' : BuiltinRule('("true" | "false")', []),
|
||||
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
|
||||
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']),
|
||||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []),
|
||||
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\""', ['char']),
|
||||
'null' : BuiltinRule('"null"', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
@@ -217,9 +217,9 @@ STRING_FORMAT_RULES = {
|
||||
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\""', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\""', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']),
|
||||
}
|
||||
|
||||
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||
@@ -319,7 +319,7 @@ class SchemaConverter:
|
||||
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
|
||||
visit(trie)
|
||||
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["]')
|
||||
return ''.join(out)
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
@@ -549,7 +549,7 @@ class SchemaConverter:
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
@@ -580,10 +580,10 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
@@ -624,7 +624,7 @@ class SchemaConverter:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
@@ -638,12 +638,12 @@ class SchemaConverter:
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
' space "]"')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"')
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
@@ -663,7 +663,7 @@ class SchemaConverter:
|
||||
min_len = schema.get('minLength', 0)
|
||||
max_len = schema.get('maxLength')
|
||||
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""')
|
||||
|
||||
elif schema_type in (None, 'integer') and \
|
||||
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
|
||||
@@ -680,7 +680,7 @@ class SchemaConverter:
|
||||
|
||||
out = ["("]
|
||||
_generate_min_max_int(min_value, max_value, out)
|
||||
out.append(") space")
|
||||
out.append(")")
|
||||
return self._add_rule(rule_name, ''.join(out))
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
@@ -765,7 +765,7 @@ class SchemaConverter:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
rule += ' space "}"'
|
||||
|
||||
return rule
|
||||
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 15)
|
||||
set(GGML_VERSION_PATCH 1)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
|
||||
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
|
||||
|
||||
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
|
||||
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
parallel_for_ggml(params, n_batch * M, [&](int begin, int end) {
|
||||
for (int idx = begin; idx < end; ++idx) {
|
||||
int batch_idx = idx / M;
|
||||
int m = idx % M;
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
|
||||
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2345,7 +2345,7 @@ class tinyBLAS_Q0_PPC {
|
||||
else if (n_aligned % 16 == 0) nc = 16;
|
||||
else nc = 8;
|
||||
}
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0);
|
||||
if (can_use_tiled) {
|
||||
matmul_tiled(m, n_aligned, mc, nc, kc);
|
||||
if (n > n_aligned) {
|
||||
@@ -3063,13 +3063,14 @@ class tinyBLAS_Q0_PPC {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
int64_t k_cur = MIN(kc, k - kk);
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
|
||||
} else {
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
|
||||
}
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
#include "col2im-1d.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
// col2im_1d: scatter-add GEMM columns to 1D signal (gather approach)
|
||||
// columns: [K*OC, T_in] -> output: [T_out, OC]
|
||||
// Supports F32, F16, BF16 data with F32 accumulator.
|
||||
|
||||
template <typename T>
|
||||
static __global__ void col2im_1d_kernel(
|
||||
const T * __restrict__ col,
|
||||
T * __restrict__ dst,
|
||||
const int T_in, const uint3 T_out_fd,
|
||||
const int OC, const int K, const int K_OC,
|
||||
const int s0, const int p0, const int total) {
|
||||
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx >= total) return;
|
||||
|
||||
// dst layout: [T_out, OC], ne[0]=T_out fastest
|
||||
const uint2 qr = fast_div_modulo((uint32_t)idx, T_out_fd); // qr.x = idx / T_out, qr.y = idx % T_out
|
||||
const int oc = (int)qr.x;
|
||||
const int t_out = (int)qr.y;
|
||||
const int t_abs = t_out + p0; // absolute position in uncropped signal
|
||||
|
||||
// Gather: find all (t_in, k) where t_in*s + k == t_abs, 0 <= k < K
|
||||
int t_in_min = (t_abs - K + s0) / s0; // ceil((t_abs - K + 1) / s)
|
||||
if (t_in_min < 0) t_in_min = 0;
|
||||
int t_in_max = t_abs / s0;
|
||||
if (t_in_max >= T_in) t_in_max = T_in - 1;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int t_in = t_in_min; t_in <= t_in_max; t_in++) {
|
||||
const int k = t_abs - t_in * s0;
|
||||
// col layout: [K*OC, T_in], column index = oc * K + k
|
||||
sum += ggml_cuda_cast<float>(col[(oc * K + k) + t_in * K_OC]);
|
||||
}
|
||||
|
||||
dst[idx] = ggml_cuda_cast<T>(sum);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
|
||||
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||
|
||||
const int K_OC = (int) src0->ne[0];
|
||||
const int T_in = (int) src0->ne[1];
|
||||
const int K = K_OC / OC;
|
||||
const int T_out = (int) dst->ne[0];
|
||||
|
||||
const uint3 T_out_fd = init_fastdiv_values((uint32_t)T_out);
|
||||
|
||||
const int total = T_out * OC;
|
||||
const int block_size = 256;
|
||||
const int num_blocks = (total + block_size - 1) / block_size;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
|
||||
(const float *)src0->data, (float *)dst->data,
|
||||
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
|
||||
(const half *)src0->data, (half *)dst->data,
|
||||
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
|
||||
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
|
||||
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("col2im_1d: unsupported type");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ggml-cuda/argsort.cuh"
|
||||
#include "ggml-cuda/binbcast.cuh"
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/col2im-1d.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/conv2d.cuh"
|
||||
@@ -3051,6 +3052,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_cuda_op_conv_transpose_1d(ctx,dst);
|
||||
break;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
ggml_cuda_op_col2im_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_POOL_2D:
|
||||
ggml_cuda_op_pool2d(ctx, dst);
|
||||
break;
|
||||
@@ -5316,6 +5320,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_F16 || src0_type == GGML_TYPE_BF16) &&
|
||||
op->type == src0_type &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op);
|
||||
} break;
|
||||
case GGML_OP_SILU_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
||||
@@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
|
||||
// transposed into VTCM.
|
||||
//
|
||||
// VTCM layouts (per thread):
|
||||
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile.
|
||||
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
|
||||
//
|
||||
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
|
||||
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
|
||||
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
|
||||
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
|
||||
static inline void transpose_src1(const float * src1_data,
|
||||
uint32_t src1_stride_inner,
|
||||
uint32_t i1_off,
|
||||
uint32_t d_inner_per_thread,
|
||||
uint32_t d_inner_stride,
|
||||
uint32_t d_conv,
|
||||
float * src1_T) {
|
||||
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
|
||||
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
src1_T[j * d_inner_per_thread + i] = src_row[j];
|
||||
src1_T[j * d_inner_stride + i] = src_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
}
|
||||
|
||||
const uint32_t d_inner_per_thread = ir1 - ir0;
|
||||
const uint32_t d_inner_stride = scctx->nrows_per_thread;
|
||||
const uint32_t d_inner_tile = scctx->d_inner_tile;
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
@@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
|
||||
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
|
||||
|
||||
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
|
||||
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
|
||||
|
||||
const uint32_t C_TILE = VLEN_FP32;
|
||||
|
||||
@@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
|
||||
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
|
||||
}
|
||||
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
|
||||
@@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
|
||||
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
|
||||
|
||||
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
|
||||
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
#endif
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
|
||||
|
||||
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
|
||||
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_abs(T x) {
|
||||
return sycl::fabs(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
|
||||
} else {
|
||||
return sycl::fabs(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::expm1(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::expm1(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_elu(T x) {
|
||||
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
|
||||
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
constexpr int ver = __INTEL_LLVM_COMPILER;
|
||||
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
|
||||
return sycl::ext::oneapi::experimental::tanh(x);
|
||||
#else
|
||||
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
|
||||
#endif
|
||||
} else {
|
||||
return sycl::tanh(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
|
||||
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
|
||||
return static_cast<T>(0.5f) * x *
|
||||
(static_cast<T>(1.0f) +
|
||||
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::exp(x);
|
||||
} else {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_silu(T x) {
|
||||
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return x / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
static __dpct_inline__ T op_erf(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::erf(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::erf(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_erf(T x) {
|
||||
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
return sycl::tanh(x);
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_relu(T x) {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sigmoid(T x) {
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sqrt(T x) {
|
||||
return sycl::sqrt(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sqrt(x);
|
||||
} else {
|
||||
return sycl::sqrt(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sin(T x) {
|
||||
return sycl::sin(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sin(x);
|
||||
} else {
|
||||
return sycl::sin(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_cos(T x) {
|
||||
return sycl::cos(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::cos(x);
|
||||
} else {
|
||||
return sycl::cos(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardsigmoid(T x) {
|
||||
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmin(
|
||||
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
|
||||
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return sycl::fmin(static_cast<T>(1.0f),
|
||||
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardswish(T x) {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
return sycl::expm1(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
|
||||
if (x <= static_cast<T>(0)) {
|
||||
return neg_infinity<T>();
|
||||
}
|
||||
return sycl::log(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::log(x);
|
||||
} else {
|
||||
return sycl::log(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float ax = op_abs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
|
||||
T neg_slope_T = static_cast<T>(negative_slope);
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
|
||||
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_floor(T x) {
|
||||
return sycl::floor(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::floor(x);
|
||||
} else {
|
||||
return sycl::floor(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_ceil(T x) {
|
||||
return sycl::ceil(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::ceil(x);
|
||||
} else {
|
||||
return sycl::ceil(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_round(T x) {
|
||||
return sycl::round(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::round(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::round(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::trunc(x);
|
||||
} else {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename F>
|
||||
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
|
||||
});
|
||||
}, negative_slope);
|
||||
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
|
||||
});
|
||||
}, min_val, max_val);
|
||||
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
|
||||
const int64_t n, const int64_t o0, const int64_t o1,
|
||||
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
|
||||
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_CHECK_RESULTS)
|
||||
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
|
||||
# the result-checking path computes a CPU reference graph via
|
||||
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN_DEBUG)
|
||||
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_RUN_TESTS)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
|
||||
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
uint32_t num_cols;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
uint32_t n_experts;
|
||||
uint32_t num_cols;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
||||
vectorized == other.vectorized;
|
||||
num_cols == other.num_cols && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.n_experts);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
if (src1->ne[1] <= 4) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.num_cols = context.dst->ne[1];
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=1"));
|
||||
|
||||
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
||||
|
||||
|
||||
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t q8_src1_binding_size = ROUNDUP_POW2(
|
||||
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
bool use_mat_vec = (dst->ne[1] <= 4);
|
||||
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
|
||||
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
@@ -3788,7 +3790,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
||||
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
|
||||
}
|
||||
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
static void ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
@@ -3800,17 +3802,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
options.nextInChain = &adapterTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
ctx->webgpu_global_ctx->adapter = std::move(adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
instance.WaitAny(instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter);
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
|
||||
|
||||
ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
|
||||
@@ -4543,20 +4548,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
// Probe for adapter support
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx->webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter);
|
||||
}
|
||||
|
||||
// WebGPU backend requires f16 support and, on native, implicit device synchronization.
|
||||
|
||||
@@ -103,7 +103,7 @@ fn main(
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
let subgroup_total = subgroupAdd(acc[0][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
@@ -126,7 +126,7 @@ fn main(
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
partial_sums[partial_index(row, thread_id)] = acc[0][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -91,61 +91,67 @@ fn main(
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[col][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
workgroupBarrier();
|
||||
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + row] = row_total;
|
||||
}
|
||||
}
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + col * params.m + row] = row_total;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[col][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
@@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
@@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q8_0
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
#if defined(LEGACY_QUANTS)
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
let inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, inner_id);
|
||||
let da = get_dm(block_byte_base);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
|
||||
let b_repacked = repack_b_qs(src1q_idx, inner_id);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
#if defined(MUL_ACC_Q4_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#if defined(MUL_ACC_Q4_1)
|
||||
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#if defined(MUL_ACC_Q8_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds);
|
||||
#endif // MUL_ACC_Q8_0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // LEGACY_QUANTS
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
@@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
@@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, tid);
|
||||
let dm = get_dm(block_byte_base);
|
||||
let scale_min = get_scale_min(block_byte_base, tid);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
#if defined(MUL_ACC_Q2_K)
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#if defined(MUL_ACC_Q4_K)
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // K_QUANTS
|
||||
|
||||
@@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product;
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_11: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
@@ -57,25 +59,28 @@ fn main(
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
let ne0_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let vec_idx = wg_linear / wg_per_vec;
|
||||
let src13_idx = vec_idx / (params.ne2 * params.ne1);
|
||||
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
|
||||
let src12_idx = vec_ne12_num / params.ne1;
|
||||
let src11_idx = vec_ne12_num % params.ne1;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
@@ -85,7 +90,7 @@ fn main(
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
let is_valid = src11_vec4_idx < ne0_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
|
||||
+7
-10
@@ -600,18 +600,15 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
|
||||
// convert fname (UTF-8)
|
||||
wchar_t * wfname = ggml_mbstowcs(fname);
|
||||
if (wfname) {
|
||||
// convert mode (ANSI)
|
||||
wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
|
||||
wchar_t * wmode_p = wmode;
|
||||
do {
|
||||
*wmode_p++ = (wchar_t)*mode;
|
||||
} while (*mode++);
|
||||
|
||||
// open file
|
||||
file = _wfopen(wfname, wmode);
|
||||
// convert mode (UTF-8)
|
||||
wchar_t * wmode = ggml_mbstowcs(mode);
|
||||
if (wmode) {
|
||||
// open file
|
||||
file = _wfopen(wfname, wmode);
|
||||
GGML_FREE(wmode);
|
||||
}
|
||||
|
||||
GGML_FREE(wfname);
|
||||
GGML_FREE(wmode);
|
||||
}
|
||||
|
||||
return file;
|
||||
|
||||
@@ -359,6 +359,7 @@ class Keys:
|
||||
CHUNK_SIZE = "clip.audio.chunk_size"
|
||||
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
|
||||
MAX_POS_EMB = "clip.audio.max_pos_emb"
|
||||
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
|
||||
@@ -1310,6 +1310,9 @@ class GGUFWriter:
|
||||
def add_audio_max_pos_emb(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
|
||||
|
||||
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
|
||||
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
|
||||
|
||||
def add_audio_projector_window_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
|
||||
|
||||
|
||||
+9
-8
@@ -558,14 +558,15 @@ extern "C" {
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
@@ -1 +1 @@
|
||||
3af5f5760e19a96427f5f7a93b79cbdf3d4b265b
|
||||
707321c4cf6d21cb4bc831aa8b687dbf01a521ce
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.47.0"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.48.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
@@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_nextn_layer_offset(int32_t offset) {
|
||||
cparams.nextn_layer_offset = offset;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
@@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
|
||||
ctx->set_embeddings_layer_inp(lid, value);
|
||||
}
|
||||
|
||||
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
|
||||
ctx->set_nextn_layer_offset(offset);
|
||||
}
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
if (!ctx) {
|
||||
return nullptr;
|
||||
|
||||
@@ -115,6 +115,7 @@ struct llama_context {
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_nextn(bool value, bool masked);
|
||||
void set_embeddings_layer_inp(uint32_t lid, bool enable);
|
||||
void set_nextn_layer_offset(int32_t offset);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ struct llama_cparams {
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
int32_t nextn_layer_offset = 0;
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
|
||||
@@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
|
||||
// Select which appended NextN block the DECODER_MTP graph runs (offset past
|
||||
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
|
||||
// chain multiple trained NextN heads. Default 0 (first head).
|
||||
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
+9
-2
@@ -682,9 +682,16 @@ struct llm_graph_params {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
|
||||
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
|
||||
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
|
||||
@@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
|
||||
return model->hparams.n_layer();
|
||||
}
|
||||
|
||||
int32_t llama_model_n_layer_nextn(const llama_model * model) {
|
||||
return model->hparams.n_layer_nextn;
|
||||
}
|
||||
|
||||
int32_t llama_model_n_head(const llama_model * model) {
|
||||
return model->hparams.n_head();
|
||||
}
|
||||
|
||||
+2
-2
@@ -932,8 +932,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
|
||||
// copy the KV pairs from the input file
|
||||
gguf_set_kv (ctx_out.get(), ml.metadata);
|
||||
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
|
||||
gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
|
||||
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION).c_str(), GGML_QNT_VERSION);
|
||||
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_FILE_TYPE).c_str(), ftype);
|
||||
|
||||
// Remove split metadata
|
||||
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
|
||||
|
||||
@@ -2813,8 +2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_softmax_impl(cur_p, true);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
|
||||
|
||||
@@ -101,11 +101,11 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) {
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
// DSA indexer
|
||||
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags);
|
||||
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags);
|
||||
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags);
|
||||
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags);
|
||||
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags);
|
||||
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
if (i < (int) hparams.n_layer_dense_lead) {
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
|
||||
|
||||
@@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
|
||||
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
res->t_layer_inp[il] = inpL;
|
||||
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
|
||||
@@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
||||
|
||||
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
res->t_layer_inp[il] = inpL;
|
||||
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
|
||||
+27
-28
@@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
};
|
||||
|
||||
auto load_block_mtp = [&](int i, bool is_first_mtp) {
|
||||
auto load_block_mtp = [&](int i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
const uint32_t n_head_l = hparams.n_head(i);
|
||||
@@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
|
||||
// The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
|
||||
// NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
|
||||
// `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
|
||||
//
|
||||
// Only the FIRST MTP block (i == n_main) is required for the
|
||||
// single-block MTP runtime; trailing MTP blocks are always tolerated
|
||||
// as missing so pruned GGUFs (block 0 only) load cleanly. Override
|
||||
// mtp_flags to NOT_REQUIRED for those.
|
||||
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
|
||||
// Multi-block MTP: every declared MTP block is required (the draft chain
|
||||
// runs all n_layer_nextn heads), so each block uses the captured
|
||||
// `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF,
|
||||
// which keeps that path correct.
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
@@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags);
|
||||
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags);
|
||||
|
||||
// dense MLP (leading dense blocks) — present if the MTP block isn't MoE
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
|
||||
@@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// NextN-specific tensors that define the MTP block.
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
@@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
load_block_trunk(i, trunk_flags);
|
||||
}
|
||||
// Only the first MTP block (i == n_main) is required at runtime — the
|
||||
// single-block-MTP graph in build_arch_graph always uses that one.
|
||||
// Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
|
||||
// all MTP layers still works) but tolerated when absent via the pruning
|
||||
// path. See scripts/prune_step35_extra_mtp.py for the pruner.
|
||||
// All n_layer_nextn MTP blocks are required — the multi-block draft chain
|
||||
// runs every head (head k at offset k). The GGUF declares the count via
|
||||
// step35.nextn_predict_layers.
|
||||
for (int i = n_layer; i < n_layer_all; ++i) {
|
||||
load_block_mtp(i, /*is_first_mtp=*/ i == n_layer);
|
||||
load_block_mtp(i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
||||
: llm_graph_context(params) {
|
||||
GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0");
|
||||
|
||||
// Single-block MTP only: always run the first trained MTP block (Qwen
|
||||
// MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
|
||||
// be a much deeper refactor than this PR justifies; the trailing MTP
|
||||
// blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
|
||||
// block 0) also work — see load_arch_tensors below and
|
||||
// scripts/prune_step35_extra_mtp.py.
|
||||
const int il = hparams.n_layer();
|
||||
// Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by
|
||||
// cparams.nextn_layer_offset (0 = first trained head). The speculative driver
|
||||
// bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps
|
||||
// single-block behavior identical to before.
|
||||
const int il = hparams.n_layer() + cparams.nextn_layer_offset;
|
||||
GGML_ASSERT(cparams.nextn_layer_offset >= 0 &&
|
||||
cparams.nextn_layer_offset < (int) hparams.n_layer_nextn &&
|
||||
"nextn_layer_offset out of range [0, n_layer_nextn)");
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
|
||||
@@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
@@ -129,7 +129,154 @@ void test_gbnf_generation(testing &t) {
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])* ("<" | "</" | "</t" | "</ta" | "</tag")?
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [<] until-0-01 | [^<] until-0
|
||||
until-0-01 ::= | [<] until-0-01 | [/] until-0-02 | [^/<] until-0
|
||||
until-0-02 ::= | [<] until-0-01 | [t] until-0-03 | [^<t] until-0
|
||||
until-0-03 ::= | [<] until-0-01 | [a] until-0-04 | [^<a] until-0
|
||||
until-0-04 ::= | [<] until-0-01 | [g] until-0-05 | [^<g] until-0
|
||||
until-0-05 ::= | [<] until-0-01 | [^<>] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar overlapping delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("\n</parameter>\n");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [\n] until-0-01 | [^\n] until-0
|
||||
until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0
|
||||
until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0
|
||||
until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0
|
||||
until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0
|
||||
until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0
|
||||
until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0
|
||||
until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0
|
||||
until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0
|
||||
until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0
|
||||
until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0
|
||||
until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0
|
||||
until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0
|
||||
until-0-13 ::= | [^\n] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
// DeepSeek-V3.2 tag prefix. The DSML token (|DSML|) embeds U+FF5C,
|
||||
// so the delimiter mixes ASCII and multi-byte codepoints.
|
||||
t.test("until grammar unicode delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("<|DSML|");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [<] until-0-01 | [^<] until-0
|
||||
until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0
|
||||
until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^<D] until-0
|
||||
until-0-03 ::= | [<] until-0-01 | [S] until-0-04 | [^<S] until-0
|
||||
until-0-04 ::= | [<] until-0-01 | [M] until-0-05 | [^<M] until-0
|
||||
until-0-05 ::= | [<] until-0-01 | [L] until-0-06 | [^<L] until-0
|
||||
until-0-06 ::= | [<] until-0-01 | [^<\uFF5C] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar multiple delimiters", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until_one_of({"ab", "cd", "ef"});
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^ace] until-0
|
||||
until-0-01 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^abce] until-0
|
||||
until-0-03 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acde] until-0
|
||||
until-0-05 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acef] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("ac grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.ac(p.until("</tag>") + p.literal("</tag>"), "</tag>");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
ac-3 ::= [<] ac-3-01 | [^<] ac-3
|
||||
ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3
|
||||
ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^<t] ac-3
|
||||
ac-3-03 ::= [<] ac-3-01 | [a] ac-3-04 | [^<a] ac-3
|
||||
ac-3-04 ::= [<] ac-3-01 | [g] ac-3-05 | [^<g] ac-3
|
||||
ac-3-05 ::= [>] | [<] ac-3-01 | [^<>] ac-3
|
||||
root ::= ac-3
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("ac grammar terminates at first delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.ac(p.until("\n</parameter>\n") + p.literal("\n</parameter>\n"), "\n</parameter>\n");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
ac-3 ::= [\n] ac-3-01 | [^\n] ac-3
|
||||
ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3
|
||||
ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3
|
||||
ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3
|
||||
ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3
|
||||
ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3
|
||||
ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3
|
||||
ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3
|
||||
ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3
|
||||
ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3
|
||||
ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3
|
||||
ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3
|
||||
ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3
|
||||
ac-3-13 ::= [\n] | [^\n] ac-3
|
||||
root ::= ac-3
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("ac grammar multiple delimiters", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.ac(p.eps(), std::vector<std::string>{"ab", "cd", "ef"});
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1
|
||||
ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1
|
||||
ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1
|
||||
ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1
|
||||
root ::= ac-1
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#undef NDEBUG
|
||||
#include <cassert>
|
||||
|
||||
int main(void) {
|
||||
static void test(void) {
|
||||
common_params params;
|
||||
|
||||
printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
|
||||
@@ -210,3 +210,13 @@ int main(void) {
|
||||
|
||||
printf("test-arg-parser: all tests OK\n\n");
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
try {
|
||||
test();
|
||||
} catch (std::exception & e) {
|
||||
fprintf(stderr, "test-arg-parser: exception: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -8433,6 +8433,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
|
||||
@@ -8449,6 +8450,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
+101
-26
@@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_split_by_role() {
|
||||
static void test_msg_token_delimiters_split() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Delimiters that share a leading token, distinguished by the second token,
|
||||
// to exercise the per-position token matching.
|
||||
const common_chat_msg_delimiters delims = {
|
||||
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
|
||||
};
|
||||
|
||||
// Empty inputs
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
|
||||
assert_equals<size_t>(0, delims.split({}).spans.size());
|
||||
|
||||
// Multi-role conversation, no leading/trailing content
|
||||
// No delimiters match -> no spans
|
||||
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
|
||||
|
||||
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
|
||||
{
|
||||
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
|
||||
const auto splits = common_chat_split_by_role(prompt, {
|
||||
{ "user", "<|user|>" },
|
||||
{ "assistant", "<|assistant|>" },
|
||||
});
|
||||
assert_equals<size_t>(3, splits.size());
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
100, 101, // Hi
|
||||
10, 12, // <assistant>
|
||||
200, 201, 202, // Hello
|
||||
10, 11, // <user>
|
||||
300, 301, // Bye
|
||||
};
|
||||
|
||||
assert_equals<std::string>("user", splits[0].role);
|
||||
assert_equals<size_t>(0, splits[0].pos);
|
||||
assert_equals<size_t>(10, splits[0].len);
|
||||
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
|
||||
const auto result = delims.split(tokens);
|
||||
const auto & spans = result.spans;
|
||||
assert_equals<size_t>(3, spans.size());
|
||||
|
||||
assert_equals<std::string>("assistant", splits[1].role);
|
||||
assert_equals<size_t>(10, splits[1].pos);
|
||||
assert_equals<size_t>(18, splits[1].len);
|
||||
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(4, spans[0].len);
|
||||
|
||||
assert_equals<std::string>("user", splits[2].role);
|
||||
assert_equals<size_t>(28, splits[2].pos);
|
||||
assert_equals<size_t>(11, splits[2].len);
|
||||
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(4, spans[1].pos);
|
||||
assert_equals<size_t>(5, spans[1].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
|
||||
assert_equals<size_t>(9, spans[2].pos);
|
||||
assert_equals<size_t>(4, spans[2].len);
|
||||
|
||||
// is_user_start() is true at the token position where a user span begins
|
||||
assert_equals(true, result.is_user_start(0));
|
||||
assert_equals(false, result.is_user_start(4)); // assistant span
|
||||
assert_equals(true, result.is_user_start(9));
|
||||
}
|
||||
|
||||
// Content before the first delimiter is not captured as a span
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
500, 501, // leading content (dropped)
|
||||
10, 11, // <user>
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const auto spans = delims.split(tokens).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(2, spans[0].pos);
|
||||
assert_equals<size_t>(3, spans[0].len);
|
||||
}
|
||||
|
||||
// Skipped regions (media chunks) are jumped over but still count as span content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
|
||||
LLAMA_TOKEN_NULL,
|
||||
LLAMA_TOKEN_NULL,
|
||||
100, // Hi
|
||||
10, 12, // <assistant>
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 3 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(2, spans.size());
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(6, spans[0].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(6, spans[1].pos);
|
||||
assert_equals<size_t>(2, spans[1].len);
|
||||
}
|
||||
|
||||
// A delimiter sequence inside a skipped region is not matched
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
10, 12, // skipped region that happens to contain delimiter tokens
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 2 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(5, spans[0].len);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5022,14 +5097,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
|
||||
tst.test(
|
||||
"```json\n\"42\" \n```")
|
||||
"```json\n\"42\"\n```")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.json_schema(const_schema)
|
||||
.expect_content(R"("42")")
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"\"42\" \n")
|
||||
"\"42\"\n")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.json_schema(const_schema)
|
||||
.expect_content(R"("42")")
|
||||
@@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_split_by_role();
|
||||
test_msg_token_delimiters_split();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
|
||||
@@ -995,6 +995,32 @@ static void test_macros(testing & t) {
|
||||
json::object(),
|
||||
"Hello, John Smith,Hi, Jane Doe"
|
||||
);
|
||||
|
||||
test_template(t, "macro with caller",
|
||||
"\
|
||||
{%- macro nest_dict(o, i, ff='') %}\n\
|
||||
{{- caller(ff) }}\n\
|
||||
{%- for k, v in o|items %}\n\
|
||||
{{- i + k + ': ' }}\n\
|
||||
{%- if v is mapping %}\n\
|
||||
{{- '{' }}\n\
|
||||
{% call(f) nest_dict(v, i + ' ') %}\n\
|
||||
{{- 'fail' if ff is undefined }}\n\
|
||||
{%- endcall %}\n\
|
||||
{{- i + '}' }}\n\
|
||||
{% else %}\n\
|
||||
{{- v|string }}\n\
|
||||
{% endif %}\n\
|
||||
{%- endfor %}\n\
|
||||
{%- endmacro %}\n\
|
||||
{%- call(f) nest_dict({'root1': 1, 'root2': {'nest1': 1, 'nest2': {'nest3': 2}}}, ' ', 'Dict') %}\n\
|
||||
{{- 'fail' if ff is defined }}\n\
|
||||
{{- f + ' {' }}\n\
|
||||
{% endcall %}\n\
|
||||
{{- '}' }}",
|
||||
json::object(),
|
||||
"Dict {\n root1: 1\n root2: {\n nest1: 1\n nest2: {\n nest3: 2\n }\n }\n}"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_namespace(testing & t) {
|
||||
|
||||
@@ -92,7 +92,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 0
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([0] | [1-9] [0-9]{0,15}) space
|
||||
root ::= ([0] | [1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -105,7 +105,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 1
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-9] [0-9]{0,15}) space
|
||||
root ::= ([1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -118,7 +118,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 3
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
|
||||
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -131,7 +131,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 9
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
|
||||
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -144,7 +144,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 10
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
|
||||
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -157,7 +157,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 25
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
|
||||
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -170,7 +170,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 30
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0"))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -183,7 +183,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": -5
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
|
||||
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -196,7 +196,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": -123
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -209,7 +209,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": -5
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
|
||||
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15}))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -222,7 +222,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 1
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-1])
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -235,7 +235,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 100
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -249,7 +249,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 23
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
|
||||
root ::= ([0-9] | ([1] [0-9] | [2] [0-3]))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -263,7 +263,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 300
|
||||
})""",
|
||||
R"""(
|
||||
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
|
||||
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00"))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -277,7 +277,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 30
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
|
||||
root ::= ([5-9] | ([1-2] [0-9] | [3] "0"))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -291,7 +291,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 42
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2]))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -305,7 +305,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 10
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
|
||||
root ::= ("-" ([0-9] | "10") | [0-9] | "10")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -333,17 +333,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"empty schema (object)",
|
||||
"{}",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= object
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -361,17 +361,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
date ::= [0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
|
||||
date-string ::= "\"" date "\"" space
|
||||
date-string ::= "\"" date "\""
|
||||
date-time ::= date "T" time
|
||||
date-time-string ::= "\"" date-time "\"" space
|
||||
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
|
||||
date-time-string ::= "\"" date-time "\""
|
||||
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
|
||||
time-string ::= "\"" time "\"" space
|
||||
time-string ::= "\"" time "\""
|
||||
tuple-0 ::= date-string
|
||||
tuple-2 ::= time-string
|
||||
tuple-3 ::= date-time-string
|
||||
uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space
|
||||
uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -383,7 +383,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char* "\"" space
|
||||
root ::= "\"" char* "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -397,7 +397,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char+ "\"" space
|
||||
root ::= "\"" char+ "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -411,7 +411,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{3,} "\"" space
|
||||
root ::= "\"" char{3,} "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -425,7 +425,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{0,3} "\"" space
|
||||
root ::= "\"" char{0,3} "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -440,7 +440,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{1,4} "\"" space
|
||||
root ::= "\"" char{1,4} "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -452,7 +452,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "boolean"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("true" | "false") space
|
||||
root ::= ("true" | "false")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -465,7 +465,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
root ::= ("-"? integral-part)
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -477,7 +477,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"const": "foo"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"foo\"" space
|
||||
root ::= "\"foo\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -489,7 +489,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"const": 123
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "123" space
|
||||
root ::= "123"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -501,7 +501,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"enum": ["red", "amber", "green", null, 42, ["foo"]]
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
|
||||
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -515,9 +515,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space (string ("," space string)*)? "]" space
|
||||
root ::= "[" space (string ("," space string)*)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -529,12 +529,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"prefixItems": { "type": "string" }
|
||||
})""",
|
||||
R"""(
|
||||
alternative-0 ::= "[" space (string ("," space string)*)? "]" space
|
||||
alternative-0 ::= "[" space (string ("," space string)*)? space "]"
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
null ::= "null" space
|
||||
null ::= "null"
|
||||
root ::= alternative-0 | null
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -546,9 +546,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space string "]" space
|
||||
root ::= "[" space string space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -562,10 +562,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space string "," space number "]" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "[" space string "," space number space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -577,18 +577,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"items": {}
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= "[" space (item ("," space item)*)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -602,18 +602,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"prefixItems": { "type": "string" }
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= "[" space (item ("," space item)*)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -627,7 +627,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -642,8 +642,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minItems": 2
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean ("," space boolean)+ "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space boolean ("," space boolean)+ space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -658,8 +658,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 0
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -674,8 +674,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 1
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean? "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space boolean? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -690,8 +690,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 2
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space (boolean ("," space boolean)?)? "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space (boolean ("," space boolean)?)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -708,11 +708,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= number | integer
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "[" space item ("," space item){2,4} space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -730,8 +730,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 5
|
||||
})""",
|
||||
R"""(
|
||||
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7]))
|
||||
root ::= "[" space item ("," space item){2,4} space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -749,8 +749,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 5
|
||||
})""",
|
||||
R"""(
|
||||
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7]))
|
||||
root ::= "[" space item ("," space item){2,4} space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -763,7 +763,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^abc?d*efg+(hij)?kl$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
|
||||
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -776,7 +776,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("[]{}()|+*?") "\"" space
|
||||
root ::= "\"" ("[]{}()|+*?") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -789,7 +789,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^\"$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("\"") "\"" space
|
||||
root ::= "\"" ("\"") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -802,7 +802,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^A|B|C|D$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
|
||||
root ::= "\"" ("A" | "B" | "C" | "D") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -816,7 +816,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
dot ::= [^\x0A\x0D]
|
||||
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
|
||||
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\""
|
||||
root-1 ::= [0-9]
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
@@ -845,9 +845,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
|
||||
root ::= "{" space b-kv "," space c-kv "," space a-kv space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -865,9 +865,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv )? "}" space
|
||||
root ::= "{" space (a-kv )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -889,9 +889,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b-rest ::= ( "," space c-kv )?
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -915,9 +915,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
d-kv ::= "\"d\"" space ":" space string
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -930,14 +930,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
additional-kv ::= string ":" space additional-value
|
||||
additional-value ::= "[" space (number ("," space number)*)? "]" space
|
||||
additional-value ::= "[" space (number ("," space number)*)? space "]"
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (additional-kv ( "," space additional-kv )* )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space (additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -949,17 +949,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"additionalProperties": true
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= object
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -971,17 +971,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= object
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -994,7 +994,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"additionalProperties": false
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "{" space "}" space
|
||||
root ::= "{" space space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1012,15 +1012,15 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1037,13 +1037,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
a-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space number
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1061,7 +1061,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"additionalProperties": {"type": "number"}
|
||||
})""",
|
||||
R"""(
|
||||
additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space number
|
||||
also-kv ::= "\"also\"" space ":" space number
|
||||
also-rest ::= ( "," space additional-kv )*
|
||||
@@ -1069,8 +1069,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1090,13 +1090,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
-rest ::= ( "," space a-kv )? a-rest
|
||||
a-kv ::= "\"a\"" space ":" space integer
|
||||
a-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["] space
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["]
|
||||
additional-kv ::= additional-k ":" space integer
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
root ::= ("-"? integral-part)
|
||||
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1116,12 +1116,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-rest ::= ( "," space aa-kv )? aa-rest
|
||||
aa-kv ::= "\"aa\"" space ":" space integer
|
||||
aa-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space integer
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1141,12 +1141,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
ab-rest ::= ( "," space ac-kv )? ac-rest
|
||||
ac-kv ::= "\"ac\"" space ":" space integer
|
||||
ac-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space integer
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1173,11 +1173,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv "}" space
|
||||
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv space "}"
|
||||
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= ref-definitions-foo
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1204,10 +1204,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
alternative-1 ::= ref-definitions-bar
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? space "}"
|
||||
ref-definitions-bar-b-kv ::= "\"b\"" space ":" space number
|
||||
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? "}" space
|
||||
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? space "}"
|
||||
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space number
|
||||
root ::= alternative-0 | alternative-1
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
@@ -1241,14 +1241,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b ::= b-0 | boolean
|
||||
b-0 ::= string
|
||||
b-kv ::= "\"b\"" space ":" space b
|
||||
boolean ::= ("true" | "false") space
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | b-kv )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space (a-kv a-rest | b-kv )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1290,8 +1290,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1311,7 +1311,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"a\"" | "\"b\"") space
|
||||
root ::= ("\"a\"" | "\"b\"")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1336,7 +1336,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"b\"" | "\"c\"") space
|
||||
root ::= ("\"b\"" | "\"c\"")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1378,13 +1378,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
number- ::= "{" space number-number-kv "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
number- ::= "{" space number-number-kv space "}"
|
||||
number-kv ::= "\"number\"" space ":" space number-
|
||||
number-number ::= "{" space number-number-root-kv "}" space
|
||||
number-number ::= "{" space number-number-root-kv space "}"
|
||||
number-number-kv ::= "\"number\"" space ":" space number-number
|
||||
number-number-root-kv ::= "\"root\"" space ":" space number
|
||||
root ::= "{" space number-kv "}" space
|
||||
root ::= "{" space number-kv space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1394,17 +1394,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"description only (no type) treated as unconstrained",
|
||||
R"""({"description": "The 0-based index of the last line to be retrieved (inclusive). If None, read until the end of the file."})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= value
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -1428,9 +1428,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
code ::= "\" \\r \\n \\\" \\\\ \"" space
|
||||
code ::= "\" \\r \\n \\\" \\\\ \""
|
||||
code-kv ::= "\"code\"" space ":" space code
|
||||
root ::= "{" space code-kv "}" space
|
||||
root ::= "{" space code-kv space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1547,7 +1547,7 @@ int main() {
|
||||
"pattern": "^(?:foo|bar)baz$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" (("foo" | "bar") "baz") "\"" space
|
||||
root ::= "\"" (("foo" | "bar") "baz") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""",
|
||||
});
|
||||
@@ -1560,7 +1560,7 @@ int main() {
|
||||
"pattern": "^(?:(?:ab)+c)?d$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ((("ab")+ "c")? "d") "\"" space
|
||||
root ::= "\"" ((("ab")+ "c")? "d") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""",
|
||||
});
|
||||
|
||||
@@ -360,9 +360,9 @@ int main(void) {
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
set(TARGET llama-cli-impl)
|
||||
|
||||
add_library(${TARGET} cli.cpp)
|
||||
add_library(${TARGET} cli.cpp
|
||||
cli-client.cpp
|
||||
cli-context.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
|
||||
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} LIBRARY)
|
||||
|
||||
+2
-1
@@ -161,7 +161,7 @@
|
||||
| `-mmu, --mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md<br/>(env: LLAMA_ARG_MMPROJ_URL) |
|
||||
| `--mmproj-auto, --no-mmproj, --no-mmproj-auto` | whether to use multimodal projector file (if available), useful when using -hf (default: enabled)<br/>(env: LLAMA_ARG_MMPROJ_AUTO) |
|
||||
| `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)<br/>(env: LLAMA_ARG_MMPROJ_OFFLOAD) |
|
||||
| `--image, --audio FILE` | path to an image or audio file. use with multimodal models, use comma-separated values for multiple files |
|
||||
| `--image, --audio, --video FILE` | path to an image, audio, or video file. use with multimodal models, use comma-separated values for multiple files |
|
||||
| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MIN_TOKENS) |
|
||||
| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MAX_TOKENS) |
|
||||
| `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) |
|
||||
@@ -174,6 +174,7 @@
|
||||
| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek-ocr, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, granite-4.0, granite-4.1, grok-2, hunyuan-dense, hunyuan-moe, hunyuan-vl, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
|
||||
| `--skip-chat-parsing, --no-skip-chat-parsing` | force a pure content parser, even if a Jinja template is specified; model will output everything in the content section, including any reasoning and/or tool calls (default: disabled)<br/>(env: LLAMA_ARG_SKIP_CHAT_PARSING) |
|
||||
| `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles |
|
||||
| `--log-prompts-dir PATH` | Log prompts to directory (only used for debugging, default: disabled) |
|
||||
| `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) |
|
||||
| `--spec-draft-threads, -td, --threads-draft N` | number of threads to use during generation (default: same as --threads) |
|
||||
| `--spec-draft-threads-batch, -tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) |
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
#include "cli-client.h"
|
||||
|
||||
#include "http.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
// generation can stall for a long time during prompt processing, so the
|
||||
// read timeout must be generous
|
||||
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
|
||||
|
||||
// upper bound for the accumulated response body kept for error reporting
|
||||
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
|
||||
|
||||
// returns the path with the base url's path prefix prepended (if any)
|
||||
static std::string join_path(const common_http_url & parts, const std::string & path) {
|
||||
if (parts.path.empty() || parts.path == "/") {
|
||||
return path;
|
||||
}
|
||||
std::string prefix = parts.path;
|
||||
if (prefix.back() == '/') {
|
||||
prefix.pop_back();
|
||||
}
|
||||
return prefix + path;
|
||||
}
|
||||
|
||||
json cli_client::get(const std::string & path) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
|
||||
auto res = cli.Get(join_path(parts, path_with_model));
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("GET " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post(const std::string & path, const json & body) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("POST " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
|
||||
std::string pending; // buffer for incomplete SSE lines
|
||||
std::string raw_body; // accumulated body, used only for error reporting
|
||||
|
||||
auto receiver = [&](const char * data, size_t len) -> bool {
|
||||
if (should_stop()) {
|
||||
return false; // aborts the request
|
||||
}
|
||||
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
|
||||
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
|
||||
}
|
||||
pending.append(data, len);
|
||||
size_t pos;
|
||||
while ((pos = pending.find('\n')) != std::string::npos) {
|
||||
std::string line = pending.substr(0, pos);
|
||||
pending.erase(0, pos + 1);
|
||||
if (!line.empty() && line.back() == '\r') {
|
||||
line.pop_back();
|
||||
}
|
||||
if (line.rfind("data: ", 0) != 0) {
|
||||
continue;
|
||||
}
|
||||
std::string payload = line.substr(6);
|
||||
if (payload == "[DONE]") {
|
||||
continue;
|
||||
}
|
||||
json event = json::parse(payload, nullptr, false);
|
||||
if (!event.is_discarded()) {
|
||||
on_data(event);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
httplib::Headers headers = {{"Accept", "text/event-stream"}};
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
|
||||
|
||||
if (!res) {
|
||||
if (res.error() == httplib::Error::Canceled && should_stop()) {
|
||||
return json(); // cancelled by the user
|
||||
}
|
||||
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
json error_body = json::parse(raw_body, nullptr, false);
|
||||
if (!error_body.is_discarded() && error_body.contains("error")) {
|
||||
return error_body;
|
||||
}
|
||||
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
|
||||
}
|
||||
return json();
|
||||
}
|
||||
|
||||
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
|
||||
int connect_attempts = 0;
|
||||
while (!is_aborted()) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get(join_path(parts, "/health"));
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
} else if (++connect_attempts >= 10) {
|
||||
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(300));
|
||||
}
|
||||
last_error = "aborted while waiting for the server to become ready";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> cli_client::list_models() {
|
||||
json resp = get("/v1/models");
|
||||
if (!resp.contains("data") || !resp.at("data").is_array()) {
|
||||
throw std::runtime_error("invalid response from /v1/models");
|
||||
}
|
||||
std::vector<std::string> models;
|
||||
for (const auto & m : resp.at("data")) {
|
||||
if (m.contains("id") && m.at("id").is_string()) {
|
||||
models.push_back(m.at("id").get<std::string>());
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// openai-like client for CLI
|
||||
struct cli_client {
|
||||
std::string server_base; // base url, for example "http://127.0.0.1:8080"
|
||||
std::string last_error; // set when wait_health() fails
|
||||
|
||||
std::string model; // optional, set when the server has multiple models (router mode)
|
||||
|
||||
// simple GET request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json get(const std::string & path);
|
||||
|
||||
// simple POST request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json post(const std::string & path, const json & body);
|
||||
|
||||
// POST request with an SSE streaming response; on_data is invoked once
|
||||
// per "data:" event; the function returns after the stream is finished:
|
||||
// a null json on graceful exit (incl. cancellation via should_stop),
|
||||
// the error response json otherwise
|
||||
json post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data);
|
||||
|
||||
// poll /health until the server is ready to accept requests
|
||||
// returns false if is_aborted returned true or the server is unreachable
|
||||
bool wait_health(const std::function<bool()> & is_aborted);
|
||||
|
||||
//
|
||||
// higher-level wrappers
|
||||
//
|
||||
|
||||
json create_chat_completion(const json & request,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
return post_sse("/v1/chat/completions", request, should_stop, on_data);
|
||||
}
|
||||
|
||||
json get_props() {
|
||||
return get("/props");
|
||||
}
|
||||
|
||||
std::vector<std::string> list_models();
|
||||
};
|
||||
@@ -0,0 +1,559 @@
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
std::atomic<bool> g_cli_interrupted = false;
|
||||
|
||||
static bool should_stop() {
|
||||
return g_cli_interrupted.load();
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
// number of values an arg consumes on the command line
|
||||
static int arg_num_values(const common_arg & opt) {
|
||||
if (opt.value_hint_2 != nullptr) {
|
||||
return 2;
|
||||
}
|
||||
if (opt.value_hint != nullptr) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string format_error_message(const json & err) {
|
||||
if (err.contains("error") && err.at("error").is_object()) {
|
||||
const auto & e = err.at("error");
|
||||
if (e.contains("message") && e.at("message").is_string()) {
|
||||
return e.at("message").get<std::string>();
|
||||
}
|
||||
}
|
||||
return err.dump();
|
||||
}
|
||||
|
||||
static std::string media_type_from_ext(const std::string & fname) {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
if (ext == ".wav" || ext == ".mp3") {
|
||||
return "audio";
|
||||
}
|
||||
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
|
||||
return "video";
|
||||
}
|
||||
return "image";
|
||||
}
|
||||
|
||||
bool cli_context::init() {
|
||||
view::init(params);
|
||||
|
||||
std::optional<view::spinner> spinner;
|
||||
|
||||
bool use_external_server = !params.server_base.empty();
|
||||
if (use_external_server) {
|
||||
std::string base = params.server_base;
|
||||
while (!base.empty() && base.back() == '/') {
|
||||
base.pop_back();
|
||||
}
|
||||
client.server_base = base;
|
||||
|
||||
spinner.emplace("Connecting to server at " + base);
|
||||
} else {
|
||||
if (params.model.path.empty() && params.model.url.empty() &&
|
||||
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
|
||||
view::show_error(
|
||||
"no model specified",
|
||||
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
|
||||
"or --server-base <url> to connect to a running llama-server"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
spinner.emplace("\n\nLoading model...");
|
||||
|
||||
server.emplace();
|
||||
if (!server->start(params)) {
|
||||
view::show_error("server start failed");
|
||||
return false;
|
||||
}
|
||||
if (!server->wait_ready(should_stop)) {
|
||||
if (!should_stop()) {
|
||||
view::show_error("the server exited before becoming ready");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
client.server_base = server->address();
|
||||
}
|
||||
|
||||
// for --server-base this is the main availability check; for a spawned
|
||||
// server it is a cheap sanity check on top of the ready signal
|
||||
auto is_aborted = [this]() {
|
||||
return should_stop() || (server && !server->alive());
|
||||
};
|
||||
bool healthy = false;
|
||||
try {
|
||||
healthy = client.wait_health(is_aborted);
|
||||
} catch (const std::exception & e) {
|
||||
client.last_error = e.what();
|
||||
}
|
||||
if (!healthy) {
|
||||
if (!should_stop()) {
|
||||
view::show_error(client.last_error);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use_external_server) {
|
||||
spinner.reset();
|
||||
if (!list_and_ask_models()) {
|
||||
return false;
|
||||
}
|
||||
// restore the spinner for the next step
|
||||
spinner.emplace("Waiting for server...");
|
||||
}
|
||||
|
||||
fetch_server_props();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::fetch_server_props() {
|
||||
try {
|
||||
json props = client.get_props();
|
||||
model_name = props.value("model_alias", "");
|
||||
if (model_name.empty()) {
|
||||
const std::string path = props.value("model_path", "");
|
||||
if (!path.empty()) {
|
||||
model_name = std::filesystem::path(path).filename().string();
|
||||
}
|
||||
}
|
||||
build_info = props.value("build_info", "");
|
||||
if (props.contains("modalities") && props.at("modalities").is_object()) {
|
||||
const auto & modalities = props.at("modalities");
|
||||
has_vision = modalities.value("vision", false);
|
||||
has_audio = modalities.value("audio", false);
|
||||
has_video = modalities.value("video", false);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
// /props can be disabled on remote servers; not fatal
|
||||
LOG_DBG("failed to fetch /props: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool cli_context::list_and_ask_models() {
|
||||
auto models = client.list_models();
|
||||
|
||||
// only one model: use it without asking
|
||||
if (models.size() == 1) {
|
||||
model_name = models[0];
|
||||
client.model = model_name;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string message = "\nAvailable models:";
|
||||
if (!models.empty()) {
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
message += "\n " + std::to_string(i + 1) + ". " + models[i];
|
||||
}
|
||||
}
|
||||
message += "\n";
|
||||
view::show_message(message);
|
||||
std::string selection;
|
||||
while (selection.empty()) {
|
||||
if (should_stop()) {
|
||||
return false;
|
||||
}
|
||||
view::user_turn user_turn;
|
||||
selection = user_turn.read_input(false, "Select model by number: ");
|
||||
if (selection.empty()) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
size_t idx = std::stoul(selection);
|
||||
if (idx > 0 && idx <= models.size()) {
|
||||
model_name = models[idx - 1];
|
||||
client.model = model_name;
|
||||
view::show_message("Selected model: " + model_name);
|
||||
break;
|
||||
}
|
||||
} catch (...) {
|
||||
// ignore
|
||||
}
|
||||
view::show_error("Invalid selection. Please enter a valid number.");
|
||||
selection.clear();
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::add_system_prompt() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void cli_context::push_user_message(const std::string & text) {
|
||||
json content;
|
||||
if (pending_media.empty()) {
|
||||
content = text;
|
||||
} else {
|
||||
// multimodal message: media parts first, then the text
|
||||
content = pending_media;
|
||||
content.push_back({
|
||||
{"type", "text"},
|
||||
{"text", text}
|
||||
});
|
||||
pending_media = json::array();
|
||||
}
|
||||
messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", content}
|
||||
});
|
||||
}
|
||||
|
||||
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
std::string encoded = base64::encode(data);
|
||||
|
||||
if (type == "audio") {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
pending_media.push_back({
|
||||
{"type", "input_audio"},
|
||||
{"input_audio", {
|
||||
{"data", encoded},
|
||||
{"format", ext == ".mp3" ? "mp3" : "wav"}
|
||||
}}
|
||||
});
|
||||
} else if (type == "video") {
|
||||
pending_media.push_back({
|
||||
{"type", "input_video"},
|
||||
{"input_video", {
|
||||
{"data", encoded}
|
||||
}}
|
||||
});
|
||||
} else {
|
||||
// the server detects the actual image type from the data
|
||||
pending_media.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", "data:image/unknown;base64," + encoded}
|
||||
}}
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
|
||||
json body = {
|
||||
{"messages", messages},
|
||||
{"stream", true},
|
||||
// in order to get timings even when we cancel mid-way
|
||||
{"timings_per_token", true},
|
||||
};
|
||||
|
||||
bool stream_error = false;
|
||||
|
||||
view::assistant_turn a;
|
||||
|
||||
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
|
||||
if (chunk.contains("error")) {
|
||||
stream_error = true;
|
||||
view::show_error(format_error_message(chunk));
|
||||
return;
|
||||
}
|
||||
if (chunk.contains("timings")) {
|
||||
const auto & t = chunk.at("timings");
|
||||
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
|
||||
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
|
||||
}
|
||||
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
|
||||
return;
|
||||
}
|
||||
const auto & choice = chunk.at("choices").at(0);
|
||||
if (!choice.contains("delta")) {
|
||||
return;
|
||||
}
|
||||
const auto & delta = choice.at("delta");
|
||||
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
|
||||
const std::string text = delta.at("reasoning_content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
|
||||
}
|
||||
}
|
||||
if (delta.contains("content") && delta.at("content").is_string()) {
|
||||
const std::string text = delta.at("content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
assistant_content += text;
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
g_cli_interrupted.store(false);
|
||||
|
||||
if (!err.is_null()) {
|
||||
view::show_error(format_error_message(err));
|
||||
return false;
|
||||
}
|
||||
return !stream_error;
|
||||
}
|
||||
|
||||
int cli_context::run() {
|
||||
add_system_prompt();
|
||||
|
||||
std::string modalities = "text";
|
||||
if (has_vision) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (has_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
if (has_video) {
|
||||
modalities += ", video";
|
||||
}
|
||||
|
||||
std::string banner;
|
||||
banner += "\n";
|
||||
banner += LLAMA_ASCII_LOGO;
|
||||
banner += "\n";
|
||||
banner += "build : " + build_info + "\n";
|
||||
banner += "model : " + model_name + "\n";
|
||||
banner += "modalities : " + modalities + "\n";
|
||||
if (!params.system_prompt.empty()) {
|
||||
banner += "using custom system prompt\n";
|
||||
}
|
||||
banner += "\n";
|
||||
banner += "available commands:\n";
|
||||
banner += " /exit or Ctrl+C stop or exit\n";
|
||||
banner += " /regen regenerate the last response\n";
|
||||
banner += " /clear clear the chat history\n";
|
||||
banner += " /read <file> add a text file\n";
|
||||
banner += " /glob <pattern> add text files using globbing pattern\n";
|
||||
if (has_vision) {
|
||||
banner += " /image <file> add an image file\n";
|
||||
}
|
||||
if (has_audio) {
|
||||
banner += " /audio <file> add an audio file\n";
|
||||
}
|
||||
if (has_video) {
|
||||
banner += " /video <file> add a video file\n";
|
||||
}
|
||||
banner += "\n";
|
||||
|
||||
view::show_message(banner);
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
return false;
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
cur_msg += content;
|
||||
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
{
|
||||
view::user_turn user_turn;
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
buffer = user_turn.read_input(params.multiline_input);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
if (!stage_media_file(fname, media_type_from_ext(fname))) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
break;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
}
|
||||
buffer = params.prompt;
|
||||
user_turn.echo(buffer);
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
}
|
||||
|
||||
if (should_stop()) {
|
||||
g_cli_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() && buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (messages.size() >= 2) {
|
||||
size_t last_idx = messages.size() - 1;
|
||||
messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
view::show_error("No message to regenerate.");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
pending_media = json::array();
|
||||
view::show_message("Chat history cleared.");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && has_vision) ||
|
||||
(string_starts_with(buffer, "/audio ") && has_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && has_video)) {
|
||||
std::string type = buffer.substr(1, 5);
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
if (!stage_media_file(fname, type)) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
continue;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
push_user_message(cur_msg);
|
||||
cur_msg.clear();
|
||||
}
|
||||
cli_timings timings;
|
||||
std::string assistant_content;
|
||||
generate_completion(assistant_content, timings);
|
||||
messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
|
||||
if (params.show_timings) {
|
||||
view::show_info(string_format(
|
||||
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
|
||||
timings.prompt_per_second,
|
||||
timings.predicted_per_second
|
||||
));
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
view::show_message("\n\nExiting...");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void cli_context::shutdown() {
|
||||
if (server) {
|
||||
server->stop();
|
||||
server.reset();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include "cli-client.h"
|
||||
#include "cli-server.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
struct cli_timings {
|
||||
double prompt_per_second = 0.0;
|
||||
double predicted_per_second = 0.0;
|
||||
};
|
||||
|
||||
// set by the SIGINT handler; cleared once the interrupt has been handled
|
||||
extern std::atomic<bool> g_cli_interrupted;
|
||||
|
||||
struct cli_context {
|
||||
common_params params;
|
||||
|
||||
cli_client client; // always initialized
|
||||
std::optional<cli_server> server; // only set when no --server-base is given
|
||||
|
||||
json messages = json::array();
|
||||
json pending_media = json::array(); // staged multimodal content parts
|
||||
|
||||
// properties of the connected server
|
||||
// will be populated by fetch_server_props()
|
||||
std::string model_name;
|
||||
std::string build_info;
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
bool has_video = false;
|
||||
|
||||
cli_context(const common_params & params) : params(params) {}
|
||||
~cli_context() {
|
||||
shutdown();
|
||||
}
|
||||
|
||||
// connect to --server-base or spawn a local llama-server child;
|
||||
// argc/argv are needed to forward the server-relevant args to the child
|
||||
bool init();
|
||||
|
||||
// run the interactive chat loop, returns the process exit code
|
||||
int run();
|
||||
|
||||
// stop the local server child (if any)
|
||||
void shutdown();
|
||||
|
||||
private:
|
||||
bool generate_completion(std::string & assistant_content, cli_timings & timings);
|
||||
void fetch_server_props();
|
||||
void add_system_prompt();
|
||||
void push_user_message(const std::string & text);
|
||||
|
||||
// check if server have multiple models (router mode)
|
||||
// if yes, list them then ask; do nothing otherwise
|
||||
bool list_and_ask_models();
|
||||
|
||||
// read a file and stage it as a multimodal content part; type is one of
|
||||
// "image", "audio", "video"; returns false if the file cannot be read
|
||||
bool stage_media_file(const std::string & fname, const std::string & type);
|
||||
};
|
||||
@@ -0,0 +1,83 @@
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "http.h"
|
||||
|
||||
// llama_server will be available as a dynamic library symbol
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
|
||||
struct cli_server {
|
||||
std::thread th;
|
||||
int port = -1;
|
||||
std::atomic<bool> is_alive = false;
|
||||
std::atomic<bool> is_stopping = false;
|
||||
|
||||
~cli_server() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void stop() {
|
||||
if (alive() && !is_stopping.exchange(true)) {
|
||||
llama_server_terminate();
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
|
||||
// spawn llama-server in a thread and interact with it via a random port
|
||||
bool start(common_params & params) {
|
||||
port = common_http_get_free_port();
|
||||
if (port <= 0) {
|
||||
fprintf(stderr, "failed to get a free port\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
is_alive.store(true, std::memory_order_release);
|
||||
|
||||
th = std::thread([&]() {
|
||||
common_params server_params = params; // copy
|
||||
server_params.port = port;
|
||||
// argc / argv are only used in router mode, we can skip them for now
|
||||
int res = llama_server(server_params, 0, nullptr);
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "llama_server exited with code %d\n", res);
|
||||
}
|
||||
is_alive.store(false, std::memory_order_release);
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string address() const {
|
||||
return "http://127.0.0.1:" + std::to_string(port);
|
||||
}
|
||||
|
||||
bool wait_ready(std::function<bool()> should_stop) {
|
||||
if (!alive()) {
|
||||
return false;
|
||||
}
|
||||
while (!should_stop()) {
|
||||
auto [cli, parts] = common_http_client(address());
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get("/health");
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
}
|
||||
if (!alive()) {
|
||||
// in case server die permanently
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool alive() const {
|
||||
return is_alive.load(std::memory_order_acquire);
|
||||
}
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user