Compare commits

...

19 Commits

Author SHA1 Message Date
Samanvya Tripathi af5c13841f common : fix tool call type detection for nullable and enum schemas (#21327)
* common : fix tool call type detection for nullable and enum schemas

* common, tests : fix grammar delegation for nullable/enum schemas and add tests

Fix enum type inference to scan all enum values (not just index 0) so
schemas like {"enum": [0, "celsius"]} correctly detect string type.

Fix schema_delegates in peg-parser to handle nullable type arrays
(["string", "null"]) and typeless enum schemas in raw mode, allowing
the tagged parser to use raw text instead of JSON-formatted strings.

Add test cases for Qwen3-Coder (TAG_WITH_TAGGED format):
- nullable string ["string", "null"]
- nullable string with null first ["null", "string"]
- nullable integer ["integer", "null"]
- enum without explicit type key
2026-04-03 17:51:23 +02:00
M1DNYT3 277ff5fff7 docker : bump cuda12 to 12.9.1 (#20920)
Co-authored-by: M1DNYT3 <m1dnyt3@MacBookPro.lan>
Co-authored-by: CISC <CISC@users.noreply.github.com>
2026-04-03 15:06:45 +02:00
jeromew 384c0076bc docs: Update build.md: HSA_OVERRIDE_GFX_VERSION clarification (#21331)
The `HSA_OVERRIDE_GFX_VERSION` variable can be used in ROCm to override an unsupported target architecture with a similar but supported target architecture.

This does not and has never worked on Windows. I think the clarification could avoid driving Windows people towards this solution that does not work.
2026-04-03 21:05:14 +08:00
Sigbjørn Skjæret 1f34806c44 jinja: coerce input for string-specific filters (#21370) 2026-04-03 15:03:33 +02:00
Aaron Teo 887535c33f ci: add more binary checks (#21349) 2026-04-03 20:50:00 +08:00
Piotr Wilkin (ilintar) d3416a4aa9 fix: remove stale assert (#21369) 2026-04-03 13:40:41 +02:00
uvos 43a4ee4a2c HIP: build eatch ci build test for a different architecture (#21337)
This helps improve our chances of finding build failures before the release workflow
builds for all architectures.
2026-04-03 11:38:22 +02:00
Tillerino f851fa5ab0 fix: add openssl to nix dependencies (#21353) (#21355) 2026-04-03 12:21:07 +03:00
Vishal Singh f1ac84119c ggml-zendnn : add MUL_MAT_ID op support for MoE models (#21315)
* ggml-zendnn : add MUL_MAT_ID op support for MoE models
- Add MUL_MAT_ID op acceleration for Mixture-of-Experts models
- MUL_MAT_ID op fallback to CPU backend if total experts > 32
- Point ZenDNN lib to latest bits ZenDNN-2026-WW13

* ggml-zendnn : add braces to sgemm failure condition for consistency

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>
2026-04-03 12:19:08 +03:00
Piotr Wilkin (ilintar) b069b10ab4 vocab: fix Gemma4 tokenizer (#21343)
* seems to work

* fix case with new line

Co-authored-by: sayap <sokann@gmail.com>

* gemma 4: fix pre tok regex

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: sayap <sokann@gmail.com>
2026-04-03 10:33:03 +02:00
Radoslav Gerganov 0c58ba3365 rpc : reuse compute graph buffers (#21299)
Reuse the buffer for the ggml context which is used for creating the
compute graph on the server side. This partially addresses a memory leak
created by the CUDA backend due to using buffer addresses as cache
keys.

ref: #21265
ref: #20315
2026-04-03 10:28:09 +03:00
Georgi Gerganov 57ace0d612 chat : avoid including json in chat.h (#21306) 2026-04-03 09:07:59 +03:00
Georgi Gerganov 39b27f0da0 (revert) kv-cache : do not quantize SWA KV cache (#21332)
This reverts commit 17193cce34.
2026-04-03 09:07:01 +03:00
Vishal Singh f49e917876 ci : add AMD ZenDNN label to PR labeler (#21345)
* ci : add AMD CPU label to PR labeler
Add automatic labeling for PRs that modify AMD CPU (ZenDNN) backend files

* ci : rename label AMD CPU to AMD ZenDNN in labeler config

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>
2026-04-03 10:35:15 +08:00
Slobodan Josic 7c7d6ce5c7 [HIP] Bump ROCm version to 7.2.1 (#21066)
Bump ROCm version on Linux from 7.2 to 7.2.1
Add gfx1102 target
Delete LLVM workaround since ROCm 7.2.1 has fix for ROCm 7.2 perf regression https://github.com/ROCm/rocm-systems/issues/2865

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-04-03 00:59:20 +02:00
Piotr Wilkin (ilintar) 5208e2d5ba fix: gemma 4 template (#21326) 2026-04-02 23:31:02 +02:00
Bartowski 7992aa7c8e tests : add unit test coverage for llama_tensor_get_type (#20112)
* Add unit test coverage for llama_tensor_get_type

* Fix merge conflicts, add more schemas

* clang formatter changes

* Trailing whitespace

* Update name

* Start rebase

* Updating files with upstream changes prior to rebase

* Changes needed from rebase

* Update attn_qkv schema, change throw behaviour

* Fix merge conflicts

* White space

* Update with latest changes to state counters

* Revert accidental personal CLAUDE.md changes

* Change quotation mark

* Reuse metadata.name since we have it

* Move test-only stuff out of llama-quant.cpp

* Hide the regex functionality back in llama-quant.cpp, use a unique pointer to a new struct 'compiled_tensor_type_patterns' which contains the patterns

* cont : inital deslop guidelines

* Cleanup based on review comments

* Continue cleanup

* Small cleanup

* Manually set proper ordering of tensors, mostly applies to gemma

* Formatting

* Update tests/test-quant-type-selection.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix merge conflicts

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-04-02 22:53:58 +02:00
Zheyuan Chen a1cfb64530 ggml-webgpu: add vectorized flash attention (#20709)
* naive vectorized version

* add vectorized flash attention

* update vec version

* remove unused path and shader

* remove unused helper functions

* add comments

* remove pad path

* ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization

* change back to vec4

* enable multi split

* enable vec path when:
- Q->ne[1] < 20
- Q->ne[0] % 32 == 0
- V->ne[0] % 4 == 0
- K->type == f16

* update flast_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select

* enable vec path for q4 and q8

* flash-attn vec nwg=1 fast path (skip tmp/reduce staging)

* use packed f16 K loads in flash-attn vec split

* use packed f16 K loads in flash-attn vec split on host side

* tune flash-attn vec f16 VEC_NE by head dim

* cleanup

* cleanup

* keep host side clean

* cleanup host side

* change back to original host wait/submit behavior

* formatting

* reverted param-buffer pool r ecfactor

* add helper functions

* ggml-webgpu: move flash-attn vec pipeline caching back into shader lib

* ggml-webgpu: remove duplicate functions

* ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation

* ggml-webgpu: revert unrelated change

* ggml-webgpu: revert deleted comment

* disable uniformity check

* remove unnecessary change

* Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl

* Update ggml/src/ggml-webgpu/ggml-webgpu.cpp

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
2026-04-02 10:40:42 -07:00
Ruben Ortlam 5803c8d115 tests: allow exporting graph ops from HF file without downloading weights (#21182)
* tests: allow exporting graph ops from HF file without downloading weights

* use unique_ptr for llama_context in HF metadata case

* fix missing non-required tensors falling back to type f32

* use unique pointers where possible

* use no_alloc instead of fixing f32 fallback

* fix missing space
2026-04-02 18:19:20 +02:00
65 changed files with 40905 additions and 7573 deletions
-97
View File
@@ -1,97 +0,0 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=13.1.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
# CUDA architecture to build for (defaults to all supported archs)
ARG CUDA_DOCKER_ARCH=default
RUN apt-get update && \
apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
WORKDIR /app
COPY . .
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc)
RUN mkdir -p /app/lib && \
find build -name "*.so*" -exec cp -P {} /app/lib \;
RUN mkdir -p /app/full \
&& cp build/bin/* /app/full \
&& cp *.py /app/full \
&& cp -r gguf-py /app/full \
&& cp -r requirements /app/full \
&& cp requirements.txt /app/full \
&& cp .devops/tools.sh /app/full/tools.sh
## Base image
FROM ${BASE_CUDA_RUN_CONTAINER} AS base
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
&& find /var/cache -type f -delete
COPY --from=build /app/lib/ /app
### Full
FROM base AS full
COPY --from=build /app/full /app
WORKDIR /app
RUN apt-get update \
&& apt-get install -y \
git \
python3 \
python3-pip \
python3-wheel \
&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
&& find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
&& find /var/cache -type f -delete
ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
ENTRYPOINT [ "/app/llama-cli" ]
### Server, Server only
FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
WORKDIR /app
HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
ENTRYPOINT [ "/app/llama-server" ]
+3 -2
View File
@@ -16,7 +16,7 @@
rocmPackages,
vulkan-headers,
vulkan-loader,
curl,
openssl,
shaderc,
useBlas ?
builtins.all (x: !x) [
@@ -160,7 +160,8 @@ effectiveStdenv.mkDerivation (finalAttrs: {
++ optionals useMpi [ mpi ]
++ optionals useRocm rocmBuildInputs
++ optionals useBlas [ blas ]
++ optionals useVulkan vulkanBuildInputs;
++ optionals useVulkan vulkanBuildInputs
++ [ openssl ];
cmakeFlags =
[
+4 -4
View File
@@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
ARG ROCM_VERSION=7.2
ARG AMDGPU_VERSION=7.2
ARG ROCM_VERSION=7.2.1
ARG AMDGPU_VERSION=7.2.1
# Target the ROCm build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
@@ -12,11 +12,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
# Unless otherwise specified, we make a fat build.
# This is mostly tied to rocBLAS supported archs.
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.1/reference/system-requirements.html
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201'
# Set ROCm architectures
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
+5
View File
@@ -27,6 +27,11 @@ IBM zDNN:
- any-glob-to-any-file:
- ggml/include/ggml-zdnn.h
- ggml/src/ggml-zdnn/**
AMD ZenDNN:
- changed-files:
- any-glob-to-any-file:
- ggml/include/ggml-zendnn.h
- ggml/src/ggml-zendnn/**
documentation:
- changed-files:
- any-glob-to-any-file:
+4 -2
View File
@@ -472,6 +472,7 @@ jobs:
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGGML_HIP_ROCWMMA_FATTN=ON \
-DGPU_TARGETS="gfx1030" \
-DGGML_HIP=ON
cmake --build build --config Release -j $(nproc)
@@ -941,7 +942,7 @@ jobs:
- name: Grab rocWMMA package
id: grab_rocwmma
run: |
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
7z x rocwmma.deb
7z x data.tar
@@ -984,12 +985,13 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/" `
-DCMAKE_BUILD_TYPE=Release `
-DLLAMA_BUILD_BORINGSSL=ON `
-DROCM_DIR="${env:HIP_PATH}" `
-DGGML_HIP=ON `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGPU_TARGETS="gfx1100" `
-DGGML_RPC=ON
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
+4 -4
View File
@@ -73,10 +73,10 @@ jobs:
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.9.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.9.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
+2 -2
View File
@@ -35,7 +35,7 @@ env:
jobs:
ubuntu-22-hip-quality-check:
runs-on: ubuntu-22.04
container: rocm/dev-ubuntu-22.04:7.2
container: rocm/dev-ubuntu-22.04:7.2.1
steps:
- name: Clone
id: checkout
@@ -59,7 +59,7 @@ jobs:
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DGPU_TARGETS=gfx908 \
-DGPU_TARGETS=gfx942 \
-DGGML_HIP=ON \
-DGGML_HIP_EXPORT_METRICS=Off \
-DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \
+12 -10
View File
@@ -639,8 +639,8 @@ jobs:
strategy:
matrix:
include:
- ROCM_VERSION: "7.2"
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
- ROCM_VERSION: "7.2.1"
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201"
build: 'x64'
steps:
@@ -662,7 +662,7 @@ jobs:
sudo apt install -y build-essential git cmake wget
- name: Setup Legacy ROCm
if: matrix.ROCM_VERSION == '7.2'
if: matrix.ROCM_VERSION == '7.2.1'
id: legacy_env
run: |
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
@@ -683,7 +683,7 @@ jobs:
sudo apt-get install -y libssl-dev rocm-hip-sdk
- name: Setup TheRock
if: matrix.ROCM_VERSION != '7.2'
if: matrix.ROCM_VERSION != '7.2.1'
id: therock_env
run: |
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
@@ -699,7 +699,6 @@ jobs:
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BACKEND_DL=ON \
-DGGML_NATIVE=OFF \
@@ -717,17 +716,20 @@ jobs:
id: tag
uses: ./.github/actions/get-tag-name
- name: Get ROCm short version
run: echo "ROCM_VERSION_SHORT=$(echo '${{ matrix.ROCM_VERSION }}' | cut -d '.' -f 1,2)" >> $GITHUB_ENV
- name: Pack artifacts
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz
windows-hip:
runs-on: windows-2022
@@ -749,7 +751,7 @@ jobs:
- name: Grab rocWMMA package
id: grab_rocwmma
run: |
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
7z x rocwmma.deb
7z x data.tar
@@ -806,7 +808,7 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_BUILD_TYPE=Release `
-DGGML_BACKEND_DL=ON `
-DGGML_NATIVE=OFF `
+30 -2
View File
@@ -221,7 +221,7 @@ function gg_run_ctest_debug {
set -e
# Check cmake and ctest are installed
# Check required binaries are installed
gg_check_build_requirements
(cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
@@ -252,7 +252,7 @@ function gg_run_ctest_release {
set -e
# Check cmake and ctest are installed
# Check required binaries are installed
gg_check_build_requirements
(cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
@@ -627,10 +627,38 @@ function gg_sum_rerank_tiny {
}
function gg_check_build_requirements {
if ! command -v git &> /dev/null; then
gg_printf 'git not found, please install'
fi
if ! command -v git-lfs &> /dev/null; then
gg_printf 'git-lfs not found, please install'
fi
if ! command -v wget &> /dev/null; then
gg_printf 'wget not found, please install'
fi
if ! command -v python3 &> /dev/null; then
gg_printf 'python3 not found, please install'
fi
if ! command -v pip3 &> /dev/null; then
gg_printf 'pip3 not found, please install'
fi
if ! python3 -m ensurepip --help &> /dev/null; then
gg_printf 'ensurepip not found, please install python3-venv package'
fi
if ! command -v cmake &> /dev/null; then
gg_printf 'cmake not found, please install'
fi
if ! command -v ccache &> /dev/null; then
gg_printf 'ccache not found, please consider installing for faster builds'
fi
if ! command -v ctest &> /dev/null; then
gg_printf 'ctest not found, please install'
fi
+5 -3
View File
@@ -537,9 +537,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
} catch (const std::exception & e) {
LOG_WRN("HF cache migration failed: %s\n", e.what());
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// maybe handle remote preset
if (!params.model.hf_repo.empty()) {
if (!params.model.hf_repo.empty() && !skip_model_download) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
@@ -570,7 +572,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
// handle model and download
{
if (!skip_model_download) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
if (params.no_mmproj) {
params.mmproj = {};
@@ -591,7 +593,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}
+171 -16
View File
@@ -7,11 +7,109 @@
#include "log.h"
#include "nlohmann/json.hpp"
#include <algorithm>
#include <stdexcept>
#include <string>
using json = nlohmann::ordered_json;
namespace {
// Gemma4-specific PEG builder extending the standard chat builder.
// Adds value type parsers that use <|\"|> as string delimiters
// instead of JSON's double quotes, and disables json-to-schema
// conversion for these types.
class common_peg_gemma4_builder {
common_chat_peg_builder & p_;
static constexpr const char * QUOTE = "<|\"|>";
public:
explicit common_peg_gemma4_builder(common_chat_peg_builder & p) : p_(p) {}
common_peg_parser gemma4_string() {
return p_.rule("gemma4-string", [&]() {
return p_.literal(QUOTE) + p_.until(QUOTE) + p_.literal(QUOTE);
});
}
common_peg_parser gemma4_number() {
return p_.rule("gemma4-number", [&]() {
auto digit1_9 = p_.chars("[1-9]", 1, 1);
auto digits = p_.chars("[0-9]");
auto int_part = p_.choice({p_.literal("0"), p_.sequence({digit1_9, p_.chars("[0-9]", 0, -1)})});
auto frac = p_.sequence({p_.literal("."), digits});
auto exp = p_.sequence({p_.choice({p_.literal("e"), p_.literal("E")}),
p_.optional(p_.chars("[+-]", 1, 1)), digits});
auto not_number_continuation = p_.negate(p_.chars("[0-9.eE+-]", 1, 1));
return p_.sequence({p_.optional(p_.literal("-")), int_part, p_.optional(frac),
p_.optional(exp), not_number_continuation});
});
}
common_peg_parser gemma4_bool() {
return p_.rule("gemma4-bool", [&]() {
return p_.choice({p_.literal("true"), p_.literal("false")});
});
}
common_peg_parser gemma4_null() {
return p_.rule("gemma4-null", [&]() {
return p_.literal("null");
});
}
common_peg_parser gemma4_dict() {
return p_.rule("gemma4-dict", [&]() {
auto ws = p_.space();
auto key = p_.until(":");
auto member = p_.sequence({key, p_.literal(":"), ws, gemma4_value()});
auto members = p_.sequence({member, p_.zero_or_more(p_.sequence({p_.literal(","), ws, member}))});
return p_.sequence({
p_.literal("{"), ws,
p_.choice({p_.literal("}"), p_.sequence({members, ws, p_.literal("}")})})
});
});
}
common_peg_parser gemma4_array() {
return p_.rule("gemma4-array", [&]() {
auto ws = p_.space();
auto elements = p_.sequence({gemma4_value(), p_.zero_or_more(p_.sequence({p_.literal(","), ws, gemma4_value()}))});
return p_.sequence({
p_.literal("["), ws,
p_.choice({p_.literal("]"), p_.sequence({elements, ws, p_.literal("]")})})
});
});
}
common_peg_parser gemma4_value() {
return p_.rule("gemma4-value", [&]() {
return p_.choice({gemma4_string(), gemma4_dict(), gemma4_array(),
gemma4_number(), gemma4_bool(), gemma4_null()});
});
}
// Select the appropriate value parser based on JSON schema type.
// Does NOT use schema() - the gemma4 types are pure PEG without
// JSON schema metadata, so GBNF is generated directly from the
// PEG structure.
common_peg_parser gemma4_value_for_type(const json & schema) {
if (!schema.contains("type") || !schema.at("type").is_string()) {
return gemma4_value();
}
std::string type = schema.at("type").get<std::string>();
if (type == "string") { return gemma4_string(); }
if (type == "number") { return gemma4_number(); }
if (type == "integer") { return gemma4_number(); }
if (type == "boolean") { return gemma4_bool(); }
if (type == "object") { return gemma4_dict(); }
if (type == "array") { return gemma4_array(); }
return gemma4_value();
}
};
} // anonymous namespace
// Helper to iterate over tools/functions
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
for (const auto & tool : tools) {
@@ -43,7 +141,9 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
// Create the result structure
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.format = (autoparser.tools.format.mode == tool_format::TAG_WITH_GEMMA4_DICT)
? COMMON_CHAT_FORMAT_PEG_GEMMA4
: COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = autoparser.preserved_tokens;
auto parser = autoparser.build_parser(inputs);
@@ -92,6 +192,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons
ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
ctx.content = &content;
ctx.reasoning = &reasoning;
// Build reasoning parser
ctx.reasoning_parser = reasoning.build_parser(ctx);
@@ -299,12 +400,34 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
for (const auto & [param_name, param_schema] : properties.items()) {
bool is_required = required.find(param_name) != required.end();
std::string type = "object";
auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object();
if (type_obj.is_string()) {
type_obj.get_to(type);
} else if (type_obj.is_object()) {
if (type_obj.contains("type") && type_obj.at("type").is_string()) {
type_obj.at("type").get_to(type);
if (param_schema.contains("type")) {
const auto & type_obj = param_schema.at("type");
if (type_obj.is_string()) {
type_obj.get_to(type);
} else if (type_obj.is_array()) {
// Handle nullable types like ["string", "null"]
for (const auto & t : type_obj) {
if (t.is_string() && t.get<std::string>() != "null") {
type = t.get<std::string>();
break;
}
}
} else if (type_obj.is_object()) {
if (type_obj.contains("type") && type_obj.at("type").is_string()) {
type_obj.at("type").get_to(type);
}
}
}
// Infer string type from enum values when type is unspecified
if (type == "object" && param_schema.contains("enum")) {
const auto & enum_vals = param_schema.at("enum");
if (enum_vals.is_array()) {
for (const auto & v : enum_vals) {
if (v.is_string()) {
type = "string";
break;
}
}
}
}
@@ -440,7 +563,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
const auto & inputs = ctx.inputs;
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
// The Gemma4 string quote token used in place of JSON "
common_peg_gemma4_builder g4(p);
static const std::string QUOTE = "<|\"|>";
common_peg_parser tool_choice = p.choice();
@@ -451,7 +574,6 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
const auto & params = func.at("parameters");
if (!params.contains("properties") || !params.at("properties").is_object()) {
// No arguments - just match the function name with empty braces
auto func_parser = p.atomic(
p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
p.tool_args(p.eps()) +
@@ -474,9 +596,33 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
std::vector<arg_entry> arg_entries;
for (const auto & [param_name, param_schema] : properties.items()) {
std::string type = "object";
auto type_v = param_schema.contains("type") ? param_schema.at("type") : json::object();
if (type_v.is_string()) type_v.get_to(type);
std::string type = "object";
if (param_schema.contains("type")) {
const auto & type_v = param_schema.at("type");
if (type_v.is_string()) {
type_v.get_to(type);
} else if (type_v.is_array()) {
// Handle nullable types like ["string", "null"]
for (const auto & t : type_v) {
if (t.is_string() && t.get<std::string>() != "null") {
type = t.get<std::string>();
break;
}
}
}
}
// Infer string type from enum values when type is unspecified
if (type == "object" && param_schema.contains("enum")) {
const auto & enum_vals = param_schema.at("enum");
if (enum_vals.is_array()) {
for (const auto & v : enum_vals) {
if (v.is_string()) {
type = "string";
break;
}
}
}
}
common_peg_parser value_parser = p.eps();
if (type == "string") {
@@ -486,9 +632,18 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
p.tool_arg_string_value(p.schema(p.until(QUOTE),
"tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) +
p.literal(QUOTE);
} else if (type == "number" || type == "integer") {
value_parser = p.tool_arg_value(g4.gemma4_number());
} else if (type == "boolean") {
value_parser = p.tool_arg_value(g4.gemma4_bool());
} else if (type == "null") {
value_parser = p.tool_arg_value(g4.gemma4_null());
} else if (type == "object") {
value_parser = p.tool_arg_value(g4.gemma4_dict());
} else if (type == "array") {
value_parser = p.tool_arg_value(g4.gemma4_array());
} else {
// Numbers, booleans: raw text up to the next comma or closing brace
value_parser = p.tool_arg_value(p.until_one_of({",", "}"}));
value_parser = p.tool_arg_value(g4.gemma4_value());
}
auto arg = p.tool_arg(
@@ -538,9 +693,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
tool_calls = p.optional(tool_calls);
}
auto content_before_tools = p.until(format.per_call_start);
auto content_before_tools = p.until_one_of({ format.per_call_start, ctx.reasoning->start });
return ctx.reasoning_parser +
(force_tools ? p.eps() : p.optional(p.content(content_before_tools))) +
(force_tools ? p.eps() : p.optional(p.content(content_before_tools) + p.optional(ctx.reasoning_parser))) +
tool_calls + p.end();
}
+1 -1
View File
@@ -1,7 +1,7 @@
#pragma once
#include "chat-auto-parser.h"
#include "peg-parser.h"
#include <functional>
#include <optional>
#include <string>
+4 -1
View File
@@ -4,6 +4,7 @@
#include "common.h"
#include "jinja/caps.h"
#include "peg-parser.h"
#include "nlohmann/json.hpp"
#include <chrono>
#include <optional>
@@ -215,12 +216,14 @@ struct tool_id_analysis {
// ============================================================================
struct analyze_content;
struct analyze_reasoning;
struct parser_build_context {
common_chat_peg_builder & p;
const generation_params & inputs;
const generation_params & inputs;
common_peg_parser reasoning_parser;
bool extracting_reasoning = false;
const analyze_reasoning * reasoning = nullptr;
const analyze_content * content = nullptr;
parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);
+2 -1
View File
@@ -104,10 +104,11 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
analysis.tools.function.name_suffix = "";
analysis.tools.arguments.start = "{";
analysis.tools.arguments.end = "}";
analysis.tools.arguments.name_prefix = "";
analysis.tools.arguments.name_suffix = ":";
analysis.tools.arguments.separator = ",";
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
analysis.reasoning.start = "<|channel>thought\n";
analysis.reasoning.start = "<|channel>thought";
analysis.reasoning.end = "<channel|>";
analysis.preserved_tokens.clear();
analysis.preserved_tokens.push_back("<|tool_call>");
+87 -1
View File
@@ -75,6 +75,84 @@ static std::string escape_json_string_inner(const std::string & s) {
return escaped;
}
static const std::string GEMMA4_QUOTE = "<|\"|>";
static std::string normalize_gemma4_to_json(const std::string & input) {
std::string result;
result.reserve(input.size() * 2);
enum Ctx { DICT, ARRAY };
std::vector<Ctx> ctx;
auto is_ws = [](char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\r'; };
auto skip_ws = [&](size_t & pos) {
while (pos < input.size() && is_ws(input[pos])) {
result += input[pos++];
}
};
auto quote_unquoted_key = [&](size_t & pos) {
if (pos < input.size() && input[pos] != '"' && input[pos] != '}') {
result += '"';
while (pos < input.size() && input[pos] != ':' && !is_ws(input[pos])) {
result += input[pos++];
}
result += '"';
skip_ws(pos);
}
};
size_t i = 0;
while (i < input.size()) {
if (i + GEMMA4_QUOTE.size() <= input.size() &&
input.compare(i, GEMMA4_QUOTE.size(), GEMMA4_QUOTE) == 0) {
result += '"';
i += GEMMA4_QUOTE.size();
continue;
}
char c = input[i];
if (c == '{') {
result += c;
ctx.push_back(DICT);
++i;
skip_ws(i);
quote_unquoted_key(i);
continue;
}
if (c == '}') {
result += c;
if (!ctx.empty()) ctx.pop_back();
++i;
continue;
}
if (c == '[') {
result += c;
ctx.push_back(ARRAY);
++i;
continue;
}
if (c == ']') {
result += c;
if (!ctx.empty()) ctx.pop_back();
++i;
continue;
}
if (c == ',' && !ctx.empty() && ctx.back() == DICT) {
result += c;
++i;
skip_ws(i);
quote_unquoted_key(i);
continue;
}
result += c;
++i;
}
return result;
}
// Convert Python-style single-quoted strings to JSON double-quoted strings
// Only converts outer string delimiters, properly handling escape sequences:
// - {'key': 'value'} -> {"key": "value"}
@@ -214,6 +292,14 @@ std::string & common_chat_peg_mapper::args_target() {
return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer;
}
std::string common_chat_peg_mapper::normalize_container_value(const std::string & input) {
return normalize_quotes_to_json(input);
}
std::string common_chat_peg_gemma4_mapper::normalize_container_value(const std::string & input) {
return normalize_quotes_to_json(normalize_gemma4_to_json(input));
}
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
const common_peg_parse_result & parse_result_arg) {
arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
@@ -352,7 +438,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
// For potential containers, normalize Python-style single quotes to JSON double quotes
bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
if (is_potential_container) {
value_content = normalize_quotes_to_json(value_content);
value_content = normalize_container_value(value_content);
}
// Try to parse as JSON value (number, bool, null, object, array)
+10 -1
View File
@@ -17,7 +17,9 @@ class common_chat_peg_mapper {
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
virtual void map(const common_peg_ast_node & node);
private:
protected:
virtual std::string normalize_container_value(const std::string & input);
private:
// Tool call handling state
std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
common_chat_tool_call * current_tool = nullptr;
@@ -30,6 +32,13 @@ class common_chat_peg_mapper {
std::string & args_target();
};
class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper {
public:
common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
protected:
std::string normalize_container_value(const std::string & input) override;
};
struct content_structure;
struct tool_call_structure;
+38 -18
View File
@@ -13,6 +13,8 @@
#include "jinja/caps.h"
#include "peg-parser.h"
#include "nlohmann/json.hpp"
#include <cstdio>
#include <cstdlib>
#include <ctime>
@@ -694,6 +696,8 @@ const char * common_chat_format_name(common_chat_format format) {
return "peg-simple";
case COMMON_CHAT_FORMAT_PEG_NATIVE:
return "peg-native";
case COMMON_CHAT_FORMAT_PEG_GEMMA4:
return "peg-gemma4";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -760,12 +764,12 @@ static void foreach_parameter(const json &
}
}
std::string common_chat_template_direct_apply(
static std::string common_chat_template_direct_apply_impl(
const common_chat_template & tmpl,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override,
const std::optional<json> & tools_override,
const std::optional<json> & additional_context) {
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt) {
jinja::context ctx(tmpl.source());
nlohmann::ordered_json inp = nlohmann::ordered_json{
@@ -812,6 +816,12 @@ std::string common_chat_template_direct_apply(
return result;
}
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
}
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
@@ -862,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
"[THINK]",
@@ -945,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
adjusted_messages.push_back(msg);
}
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
// Check if we need to replace the return token with end token during
// inference and without generation prompt. For more details see:
@@ -1072,7 +1082,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
">>>all",
@@ -1166,7 +1176,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
@@ -1289,7 +1299,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
@@ -1368,7 +1378,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
@@ -1439,7 +1449,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = false;
data.preserved_tokens = {
@@ -1722,9 +1732,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
}
params.add_generation_prompt = false;
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
params.add_generation_prompt = true;
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
params.generation_prompt = diff.right;
@@ -1758,7 +1768,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
common_chat_params data;
auto params_copy = params;
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.generation_prompt = params.generation_prompt;
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
@@ -1905,8 +1915,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
// Try to extract any partial results from what was successfully parsed
common_chat_msg msg;
msg.role = "assistant";
auto mapper = common_chat_peg_mapper(msg);
mapper.from_ast(ctx.ast, result);
std::unique_ptr<common_chat_peg_mapper> mapper;
if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
} else {
mapper = std::make_unique<common_chat_peg_mapper>(msg);
}
mapper->from_ast(ctx.ast, result);
if (ctx.is_debug()) {
fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
@@ -1921,8 +1936,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
common_chat_msg msg;
msg.role = "assistant";
auto mapper = common_chat_peg_mapper(msg);
mapper.from_ast(ctx.ast, result);
std::unique_ptr<common_chat_peg_mapper> mapper;
if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
} else {
mapper = std::make_unique<common_chat_peg_mapper>(msg);
}
mapper->from_ast(ctx.ast, result);
if (ctx.is_debug()) {
fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());
+10 -46
View File
@@ -3,12 +3,12 @@
#pragma once
#include "common.h"
#include "jinja/parser.h"
#include "nlohmann/json_fwd.hpp"
#include "peg-parser.h"
#include "jinja/parser.h"
#include "jinja/runtime.h"
#include "jinja/caps.h"
#include "nlohmann/json.hpp"
#include "nlohmann/json_fwd.hpp"
#include <chrono>
#include <functional>
@@ -19,8 +19,6 @@
using chat_template_caps = jinja::caps;
using json = nlohmann::ordered_json;
#include <nlohmann/json_fwd.hpp>
struct common_chat_templates;
namespace autoparser {
@@ -75,41 +73,9 @@ struct common_chat_template {
const std::string & bos_token() const { return bos_tok; }
const std::string & eos_token() const { return eos_tok; }
// TODO: this is ugly, refactor it somehow
json add_system(const json & messages, const std::string & system_prompt) const {
GGML_ASSERT(messages.is_array());
auto msgs_copy = messages;
if (!caps.supports_system_role) {
if (msgs_copy.empty()) {
msgs_copy.insert(msgs_copy.begin(), json{
{"role", "user"},
{"content", system_prompt}
});
} else {
auto & first_msg = msgs_copy[0];
if (!first_msg.contains("content")) {
first_msg["content"] = "";
}
first_msg["content"] = system_prompt + "\n\n"
+ first_msg["content"].get<std::string>();
}
} else {
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
msgs_copy.insert(msgs_copy.begin(), json{
{"role", "system"},
{"content", system_prompt}
});
} else if (msgs_copy[0].at("role") == "system") {
msgs_copy[0]["content"] = system_prompt;
}
}
return msgs_copy;
}
chat_template_caps original_caps() const {
return caps;
}
};
struct common_chat_msg {
@@ -184,6 +150,7 @@ enum common_chat_format {
// These are intended to be parsed by the PEG parser
COMMON_CHAT_FORMAT_PEG_SIMPLE,
COMMON_CHAT_FORMAT_PEG_NATIVE,
COMMON_CHAT_FORMAT_PEG_GEMMA4,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -256,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
@@ -274,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
bool use_jinja,
const std::map<std::string, std::string> & chat_template_kwargs);
const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
const char * common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
// used by arg and server
const char * common_reasoning_format_name(common_reasoning_format format);
@@ -302,7 +269,4 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
std::string common_chat_template_direct_apply(
const common_chat_template & tmpl,
const autoparser::generation_params & inputs,
const std::optional<json> & messages_override = std::nullopt,
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt);
const autoparser::generation_params & inputs);
+1
View File
@@ -1442,6 +1442,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.progress_callback = params.load_progress_callback;
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
mparams.no_alloc = params.no_alloc;
return mparams;
}
+1
View File
@@ -679,6 +679,7 @@ struct common_params {
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;
bool no_alloc = false; // Don't allocate model buffers
};
// call once at the start of a program if it uses libcommon
+13
View File
@@ -306,6 +306,19 @@ value filter_expression::execute_impl(context & ctx) {
filter_id = "strip"; // alias
}
JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
// TODO: Refactor filters so this coercion can be done automatically
if (!input->is_undefined() && !is_val<value_string>(input) && (
filter_id == "capitalize" ||
filter_id == "lower" ||
filter_id == "replace" ||
filter_id == "strip" ||
filter_id == "title" ||
filter_id == "upper" ||
filter_id == "wordcount"
)) {
JJ_DEBUG("Coercing %s to String for '%s' filter", input->type().c_str(), filter_id.c_str());
input = mk_val<value_string>(input->as_string());
}
return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));
} else if (is_stmt<call_expression>(filter)) {
+16 -16
View File
@@ -465,8 +465,9 @@ const func_builtins & value_int_t::get_builtins() const {
double val = static_cast<double>(args.get_pos(0)->as_int());
return mk_val<value_float>(val);
}},
{"tojson", tojson},
{"safe", tojson},
{"string", tojson},
{"tojson", tojson},
};
return builtins;
}
@@ -485,8 +486,9 @@ const func_builtins & value_float_t::get_builtins() const {
int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
return mk_val<value_int>(val);
}},
{"tojson", tojson},
{"safe", tojson},
{"string", tojson},
{"tojson", tojson},
};
return builtins;
}
@@ -771,6 +773,11 @@ const func_builtins & value_string_t::get_builtins() const {
const func_builtins & value_bool_t::get_builtins() const {
static const func_handler tostring = [](const func_args & args) -> value {
args.ensure_vals<value_bool>();
bool val = args.get_pos(0)->as_bool();
return mk_val<value_string>(val ? "True" : "False");
};
static const func_builtins builtins = {
{"default", default_value},
{"int", [](const func_args & args) -> value {
@@ -783,11 +790,8 @@ const func_builtins & value_bool_t::get_builtins() const {
bool val = args.get_pos(0)->as_bool();
return mk_val<value_float>(val ? 1.0 : 0.0);
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_bool>();
bool val = args.get_pos(0)->as_bool();
return mk_val<value_string>(val ? "True" : "False");
}},
{"safe", tostring},
{"string", tostring},
{"tojson", tojson},
};
return builtins;
@@ -1100,18 +1104,14 @@ const func_builtins & value_object_t::get_builtins() const {
}
const func_builtins & value_none_t::get_builtins() const {
static const func_handler tostring = [](const func_args &) -> value {
return mk_val<value_string>("None");
};
static const func_builtins builtins = {
{"default", default_value},
{"tojson", tojson},
{"string", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"safe", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"strip", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"string", tostring},
{"safe", tostring},
{"items", empty_value_fn<value_array>},
{"map", empty_value_fn<value_array>},
{"reject", empty_value_fn<value_array>},
+17 -1
View File
@@ -1561,7 +1561,23 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
if (!s.schema) {
return true;
}
if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") {
if (s.raw && s.schema->contains("type")) {
const auto & type_val = s.schema->at("type");
if (type_val.is_string() && type_val == "string") {
return true;
}
// Handle nullable types like ["string", "null"] - delegate when the
// non-null type is string, since the tagged format uses raw text
if (type_val.is_array()) {
for (const auto & t : type_val) {
if (t.is_string() && t.get<std::string>() != "null") {
return t.get<std::string>() == "string";
}
}
}
}
// Delegate for enum schemas in raw mode - enum values are literal strings
if (s.raw && !s.schema->contains("type") && s.schema->contains("enum")) {
return true;
}
return false;
-3
View File
@@ -7464,9 +7464,6 @@ class Gemma4Model(Gemma3Model):
assert len(tokens) == vocab.vocab_size
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
# but I don't have time to dive into them right now;
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
self.gguf_writer.add_tokenizer_model("gemma4")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
+5 -4
View File
@@ -57,13 +57,14 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
## Supported Operations
The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** operations only. Other operations are handled by the standard CPU backend.
The ZenDNN backend accelerates **matrix multiplication (MUL_MAT)** and **expert-based matrix multiplication (MUL_MAT_ID)** operations. Other operations are handled by the standard CPU backend.
| Operation | Status | Notes |
|:-------------|:-------:|:----------------------------------------------:|
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
| MUL_MAT_ID | Support | Accelerated via ZenDNN LowOHA MatMul (MoE) |
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
*Note:* Since MUL_MAT and MUL_MAT_ID are accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs and Mixture-of-Experts models).
## DataType Supports
@@ -181,7 +182,7 @@ For detailed profiling and logging options, refer to the [ZenDNN Logging Documen
## Known Issues
- **Limited operation support**: Currently only matrix multiplication (MUL_MAT) is accelerated via ZenDNN. Other operations fall back to the standard CPU backend.
- **Limited operation support**: Currently matrix multiplication (MUL_MAT) and expert-based matrix multiplication (MUL_MAT_ID) are accelerated via ZenDNN. Other operations fall back to the standard CPU backend. Future updates may expand supported operations.
- **BF16 support**: BF16 operations require AMD Zen 4 or Zen 5 architecture (EPYC 9004/9005 series). On older CPUs, operations will use FP32.
- **NUMA awareness**: For multi-socket systems, manual NUMA binding may be required for optimal performance.
@@ -216,4 +217,4 @@ Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-t
## TODO
- Expand operation support beyond MUL_MAT (attention operations, activations, etc.)
- Expand operation support beyond MUL_MAT and MUL_MAT_ID (attention operations, activations, etc.)
+1 -1
View File
@@ -389,7 +389,7 @@ You can download it from your Linux distro's package manager or from here: [ROCm
The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. Note that [`HSA_OVERRIDE_GFX_VERSION`] is [not supported on Windows](https://github.com/ROCm/ROCm/issues/2654)
### Unified Memory
+1 -1
View File
@@ -68,7 +68,7 @@ Legend:
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | | ❌ |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | 🟡 | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
+2773 -7213
View File
File diff suppressed because it is too large Load Diff
+6 -5
View File
@@ -1009,8 +1009,8 @@ public:
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
struct stored_graph {
ggml_context_ptr ctx_ptr;
ggml_cgraph * graph;
std::vector<uint8_t> buffer;
ggml_cgraph * graph;
};
private:
@@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
if (stored_graphs[device].buffer.size() < buf_size) {
stored_graphs[device].buffer.resize(buf_size);
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ NULL,
/*.mem_buffer =*/ stored_graphs[device].buffer.data(),
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
@@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
stored_graphs[device].graph = graph;
return true;
}
+191 -39
View File
@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0;
};
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
std::shared_ptr<void> decisions;
};
struct ggml_webgpu_ssm_conv_shader_decisions {
uint32_t block_size;
uint32_t tokens_per_wg;
@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
bool use_vec;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
}
};
@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
ggml_webgpu_hash_combine(seed, key.use_vec);
return seed;
}
};
@@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t wg_size = 0;
};
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
// Keep conservative defaults unless this is the f16 vec-split shape family.
if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
return 1u;
}
// Head-dim specializations used by the tuned vec f16 path.
switch (key.head_dim_qk) {
case 64: return 2u;
case 96: return 4u;
case 128: return 1u;
case 192: return 2u;
case 576: return 2u;
default: return 1u;
}
}
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
uint32_t head_dim_v;
uint32_t wg_size;
};
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.wg_size);
return seed;
}
};
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
}
struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_reduce";
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
variant += std::string("_wg") + std::to_string(context.max_wg_size);
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
struct ggml_webgpu_flash_attn_blk_pipeline_key {
uint32_t q_tile;
uint32_t kv_tile;
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
return q_tile == other.q_tile && kv_tile == other.kv_tile;
}
};
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_tile);
ggml_webgpu_hash_combine(seed, key.kv_tile);
return seed;
}
};
struct ggml_webgpu_flash_attn_blk_shader_lib_context {
ggml_webgpu_flash_attn_blk_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_blk";
defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
variant += std::string("_qt") + std::to_string(context.key.q_tile);
defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
uint32_t wg_size = 1;
while ((wg_size << 1) <= context.max_wg_size) {
wg_size <<= 1;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
variant += std::string("_wg") + std::to_string(wg_size);
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
@@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
webgpu_pipeline,
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
flash_attn_vec_reduce_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
webgpu_pipeline,
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
@@ -1673,24 +1798,8 @@ class ggml_webgpu_shader_lib {
return repeat_pipelines[key];
}
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
(context.src1->ne[1] % context.sg_mat_n == 0);
ggml_webgpu_flash_attn_pipeline_key key = {
.kv_type = context.src1->type,
.head_dim_qk = (uint32_t) context.src0->ne[0],
.head_dim_v = (uint32_t) context.src2->ne[0],
.kv_direct = kv_direct,
.has_mask = has_mask,
.has_sinks = has_sinks,
.uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
};
auto it = flash_attn_pipelines.find(key);
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
auto it = flash_attn_pipelines.find(context.key);
if (it != flash_attn_pipelines.end()) {
return it->second;
}
@@ -1698,7 +1807,7 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "flash_attn";
switch (key.kv_type) {
switch (context.key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
@@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib {
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(key.kv_type);
variant += std::string("_") + ggml_type_name(context.key.kv_type);
if (key.has_mask) {
if (context.key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (key.has_sinks) {
if (context.key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (key.uses_logit_softcap) {
if (context.key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (key.kv_direct) {
if (context.key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
if (context.key.has_mask && context.key.use_vec) {
defines.push_back("BLK");
variant += "_blk";
}
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
uint32_t q_tile = context.sg_mat_m;
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile =
std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
context.wg_mem_limit_bytes, context.max_subgroup_size }),
std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (key.kv_direct) {
if (context.key.use_vec) {
q_tile = 1;
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
}
if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
@@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
uint32_t wg_size = 0;
if (context.key.use_vec) {
wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
} else {
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
pipeline.context = decisions;
flash_attn_pipelines[context.key] = pipeline;
return flash_attn_pipelines[context.key];
}
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
flash_attn_pipelines[key] = pipeline;
return flash_attn_pipelines[key];
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
auto it = flash_attn_blk_pipelines.find(context.key);
if (it != flash_attn_blk_pipelines.end()) {
return it->second;
}
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
flash_attn_blk_pipelines[context.key] = pipeline;
return flash_attn_blk_pipelines[context.key];
}
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
auto it = flash_attn_vec_reduce_pipelines.find(context.key);
if (it != flash_attn_vec_reduce_pipelines.end()) {
return it->second;
}
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
flash_attn_vec_reduce_pipelines[context.key] = pipeline;
return flash_attn_vec_reduce_pipelines[context.key];
}
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
+309 -14
View File
@@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
for (size_t i = 0; i < params_bufs_list.size(); i++) {
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
@@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
#ifndef __EMSCRIPTEN__
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
@@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = Q,
.src1 = K,
.src2 = V,
.src3 = mask,
.src4 = sinks,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
(Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
const uint32_t vec_nwg_cap =
std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
const bool use_blk = use_vec && has_mask;
ggml_webgpu_flash_attn_pipeline_key key = {
.kv_type = K->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
.use_vec = use_vec,
};
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
.key = key,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
wgpu::Buffer blk_buf = {};
uint64_t blk_size_bytes = 0;
uint32_t blk_nblk0 = 0;
uint32_t blk_nblk1 = 0;
uint32_t blk_batch_count = 0;
if (use_vec) {
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
nwg = std::min(nwg, vec_nwg_cap);
GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
const bool use_vec_reduce = nwg > 1u;
GGML_ASSERT(nrows <= UINT32_MAX);
uint64_t tmp_stats_base = 0;
uint64_t tmp_size_bytes = 0;
wgpu::Buffer tmp_buf = {};
uint64_t tmp_bind_offset = 0;
uint64_t tmp_bind_size = 0;
const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
if (use_vec_reduce) {
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
tmp_stats_base = tmp_data_elems;
tmp_size_bytes =
ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
tmp_buf = ggml_webgpu_tensor_buf(dst);
tmp_bind_offset = scratch_offset;
tmp_bind_size = tmp_size_bytes;
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
} else {
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
tmp_buf = ggml_webgpu_tensor_buf(dst);
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
}
webgpu_pipeline blk_pipeline;
std::vector<uint32_t> blk_params;
std::vector<wgpu::BindGroupEntry> blk_entries;
if (use_blk) {
GGML_ASSERT(has_mask);
blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
blk_buf = ggml_webgpu_tensor_buf(dst);
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
.key =
{
.q_tile = decisions->q_tile,
.kv_tile = decisions->kv_tile,
},
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
blk_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
(uint32_t) Q->ne[1], // seq_len_q
(uint32_t) K->ne[1], // seq_len_kv
stride_mask3, // stride_mask3
blk_nblk0, // nblk0
blk_nblk1, // nblk1
};
blk_entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(mask),
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
.size = ggml_webgpu_tensor_binding_size(ctx, mask) },
{ .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
};
scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
}
std::vector<uint32_t> split_params = params;
if (use_blk) {
split_params.push_back(0u); // blk_base
split_params.push_back(blk_nblk0); // blk_nblk0
split_params.push_back(blk_nblk1); // blk_nblk1
}
split_params.push_back(0u); // tmp_data_base
split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base
split_params.push_back(nwg); // nwg
std::vector<wgpu::BindGroupEntry> split_entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(Q),
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(K),
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(V),
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
.size = ggml_webgpu_tensor_binding_size(ctx, V) },
};
uint32_t split_binding_index = 3;
if (has_mask) {
split_entries.push_back({ .binding = split_binding_index++,
.buffer = ggml_webgpu_tensor_buf(mask),
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
}
if (has_sinks) {
split_entries.push_back({ .binding = split_binding_index++,
.buffer = ggml_webgpu_tensor_buf(sinks),
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
}
if (use_blk) {
split_entries.push_back(
{ .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
}
split_entries.push_back(
{ .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
split_entries.push_back({ .binding = split_binding_index++,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
webgpu_pipeline reduce_pipeline;
std::vector<uint32_t> reduce_params;
std::vector<wgpu::BindGroupEntry> reduce_entries;
if (use_vec_reduce) {
const uint32_t reduce_wg_size = std::max(
32u,
std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
.key =
{
.head_dim_v = (uint32_t) V->ne[0],
.wg_size = reduce_wg_size,
},
.max_wg_size = reduce_wg_size,
};
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
reduce_params = {
(uint32_t) nrows, // nrows
(uint32_t) Q->ne[1], // seq_len_q
(uint32_t) Q->ne[2], // n_heads
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst
nwg, // nwg
0u, // tmp_data_base
(uint32_t) tmp_stats_base, // tmp_stats_base
};
reduce_entries = {
{ .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
};
}
const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
GGML_ASSERT(split_wg_total <= UINT32_MAX);
std::vector<webgpu_pipeline> pipelines;
std::vector<std::vector<uint32_t>> params_list;
std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
if (use_blk) {
pipelines.push_back(blk_pipeline);
params_list.push_back(std::move(blk_params));
entries_list.push_back(std::move(blk_entries));
workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
}
pipelines.push_back(pipeline);
params_list.push_back(std::move(split_params));
entries_list.push_back(std::move(split_entries));
workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
if (use_vec_reduce) {
pipelines.push_back(reduce_pipeline);
params_list.push_back(std::move(reduce_params));
entries_list.push_back(std::move(reduce_entries));
workgroups_list.push_back({ (uint32_t) nrows, 1u });
}
return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
entries_list, workgroups_list);
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
#endif
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
std::vector<webgpu_submission> subs;
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
@@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
}
}
break;
case GGML_OP_FLASH_ATTN_EXT:
{
const ggml_tensor * Q = tensor->src[0];
const ggml_tensor * K = tensor->src[1];
const ggml_tensor * V = tensor->src[2];
const ggml_tensor * mask = tensor->src[3];
const ggml_tensor * sinks = tensor->src[4];
if (Q && K && V) {
GGML_UNUSED(sinks);
const bool kv_direct = (K->type == GGML_TYPE_F16) &&
(Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool use_vec =
(Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
(V->type == K->type);
if (use_vec) {
const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
const size_t limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const size_t q_tile = sg_mat_m;
const size_t base_q_bytes =
(Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!kv_direct) {
bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
}
if (mask != nullptr) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
uint32_t kv_tile =
((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
if (kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= sg_mat_n;
}
}
const uint32_t vec_nwg_cap = std::max(
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
nwg = std::min(nwg, vec_nwg_cap);
const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
if (nwg > 1u) {
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
const size_t tmp_size_bytes = ROUNDUP_POW2(
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += tmp_size_bytes + align;
}
if (mask != nullptr) {
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
const uint32_t stride_mask3 =
(uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
const size_t blk_size_bytes =
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += blk_size_bytes + align;
}
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
}
}
}
break;
default:
break;
}
@@ -0,0 +1,105 @@
diagnostic(off, subgroup_uniformity);
enable f16;
#define Q_TILE 1
#define KV_TILE 32
#define WG_SIZE 32
struct Params {
offset_mask: u32,
seq_len_q: u32,
seq_len_kv: u32,
stride_mask3: u32,
// Number of KV blocks and Q blocks per batch.
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
nblk0: u32,
nblk1: u32,
};
@group(0) @binding(0) var<storage, read> mask: array<f16>;
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
@group(0) @binding(2) var<uniform> params: Params;
const MASK_MIN: f32 = -65504.0;
const MASK_MAX: f32 = 65504.0;
var<workgroup> wg_min: array<f32, WG_SIZE>;
var<workgroup> wg_max: array<f32, WG_SIZE>;
var<workgroup> wg_any: array<u32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>) {
// Dispatch mapping:
// - x indexes KV blocks
// - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
let kv_blk = wg_id.x;
let y = wg_id.y;
let q_blk = y % params.nblk1;
let batch_idx = y / params.nblk1;
if (kv_blk >= params.nblk0) {
return;
}
let q_start = q_blk * Q_TILE;
let k_start = kv_blk * KV_TILE;
let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
// We keep min/max to classify:
// - fully masked (max <= MASK_MIN)
// - all-zero mask (min == 0 && max == 0)
// - mixed/general mask
var local_min = MASK_MAX;
var local_max = -MASK_MAX;
var local_any = 0u;
for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
let q_row = q_start + q_rel;
if (q_row >= params.seq_len_q) {
continue;
}
let row_base = mask_batch_base + q_row * params.seq_len_kv;
for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
let k_col = k_start + k_rel;
if (k_col >= params.seq_len_kv) {
continue;
}
let mv = f32(mask[row_base + k_col]);
local_min = min(local_min, mv);
local_max = max(local_max, mv);
local_any = 1u;
}
}
wg_min[local_id.x] = local_min;
wg_max[local_id.x] = local_max;
wg_any[local_id.x] = local_any;
workgroupBarrier();
// Thread 0 writes one state per block.
if (local_id.x == 0u) {
var mmin = wg_min[0];
var mmax = wg_max[0];
var many = wg_any[0];
for (var i = 1u; i < WG_SIZE; i += 1u) {
mmin = min(mmin, wg_min[i]);
mmax = max(mmax, wg_max[i]);
many = max(many, wg_any[i]);
}
var state = 0u;
if (many != 0u) {
if (mmax <= MASK_MIN) {
state = 0u;
} else if (mmin == 0.0 && mmax == 0.0) {
state = 2u;
} else {
state = 1u;
}
}
let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
blk[blk_idx] = state;
}
}
@@ -0,0 +1,78 @@
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
// Default values
#define HEAD_DIM_V 64
#define WG_SIZE 128
struct Params {
nrows: u32,
seq_len_q: u32,
n_heads: u32,
offset_dst: u32,
nwg: u32,
tmp_data_base: u32,
tmp_stats_base: u32,
};
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(2) var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
let rid = wg_id.x;
if (rid >= params.nrows) {
return;
}
let rows_per_batch = params.n_heads * params.seq_len_q;
let batch_idx = rid / rows_per_batch;
let rem = rid % rows_per_batch;
let head_idx = rem / params.seq_len_q;
let q_row = rem % params.seq_len_q;
let dst2_stride = HEAD_DIM_V * params.n_heads;
let dst3_stride = dst2_stride * params.seq_len_q;
let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;
let thread = sg_inv_id;
if (params.nwg > subgroup_size) {
return;
}
let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
let active_thread = thread < params.nwg;
let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
let m = subgroupMax(mi);
let ms = select(0.0, exp(mi - m), active_thread);
let s = subgroupAdd(si * ms);
let inv_s = select(0.0, 1.0 / s, s != 0.0);
let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
if (active_thread) {
let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
}
let sum_x = subgroupAdd(weighted.x);
let sum_y = subgroupAdd(weighted.y);
let sum_z = subgroupAdd(weighted.z);
let sum_w = subgroupAdd(weighted.w);
if (thread == 0u) {
let dst_vec_index = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
}
}
}
@@ -0,0 +1,729 @@
diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
#define SG_MAT_M 8
#define SG_MAT_N 8
#define SG_MAT_K 8
#define Q_TILE SG_MAT_M
#define KV_TILE 16
#define WG_SIZE 64
#ifndef VEC_NE
#define VEC_NE 4u
#endif
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
#if defined(KV_Q4_0)
#define NQ 16
#define F16_PER_BLOCK 9
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
#define F16_PER_BLOCK 17
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}
fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
struct Params {
offset_q: u32,
offset_k: u32,
offset_v: u32,
offset_mask: u32,
offset_sinks: u32,
offset_dst: u32,
// shapes of Q/K/V
n_heads: u32,
seq_len_q: u32,
seq_len_kv: u32,
// strides (in elements)
stride_q1: u32,
stride_q2: u32,
stride_q3: u32,
stride_k1: u32,
stride_k2: u32,
stride_k3: u32,
stride_v1: u32,
stride_v2: u32,
stride_v3: u32,
stride_mask3: u32,
// repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
q_per_kv: u32,
// softmax params
scale: f32,
max_bias: f32,
logit_softcap: f32,
n_head_log2: f32,
m0: f32,
m1: f32,
#ifdef BLK
blk_base: u32,
blk_nblk0: u32,
blk_nblk1: u32,
#endif
tmp_data_base: u32,
tmp_stats_base: u32,
nwg: u32,
};
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#endif
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
#else
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#endif
#if defined(MASK) && defined(SINKS)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
#ifdef BLK
#define BLK_BINDING 5
#define TMP_BINDING 6
#define DST_BINDING 7
#define PARAMS_BINDING 8
#else
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#endif
#elif defined(MASK)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
#ifdef BLK
#define BLK_BINDING 4
#define TMP_BINDING 5
#define DST_BINDING 6
#define PARAMS_BINDING 7
#else
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#endif
#elif defined(SINKS)
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#define TMP_BINDING 4
#define DST_BINDING 5
#define PARAMS_BINDING 6
#else
#define TMP_BINDING 3
#define DST_BINDING 4
#define PARAMS_BINDING 5
#endif
#ifdef BLK
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
#endif
@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
// Just a very small float value.
const FLOAT_MIN: f32 = -1.0e9;
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
// we can reuse the same shmem for K and V since we only need one at a time
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
#endif
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
#endif
// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
// Storage for row max and exp sum during online softmax
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
var<workgroup> blk_state_wg: u32;
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
var v = select(FLOAT_MIN,
f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
if (apply_mask) {
var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
v += select(mask_val, slope * mask_val, has_bias);
}
#endif
return v;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
// initialize row max for online softmax
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
row_max_shmem[i] = FLOAT_MIN;
exp_sum_shmem[i] = 0.0;
}
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
o_shmem[i] = 0.0;
}
// workgroups per head/batch
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
let wg_per_batch = wg_per_head * params.n_heads;
let dst2_stride = HEAD_DIM_V * params.n_heads;
let dst3_stride = dst2_stride * params.seq_len_q;
let iwg = wg_id.x % params.nwg;
let base_wg_id = wg_id.x / params.nwg;
// batch index
let batch_idx = base_wg_id / wg_per_batch;
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
let wg_in_batch = base_wg_id % wg_per_batch;
// head index
let head_idx = wg_in_batch / wg_per_head;
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
let k_head_idx = head_idx / params.q_per_kv;
let v_head_idx = k_head_idx;
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
// starting Q row for this workgroup
let wg_in_head = wg_in_batch % wg_per_head;
let q_row_start = wg_in_head * Q_TILE;
#ifdef MASK
// mask offset
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif
let head = f32(head_idx);
let has_bias = params.max_bias > 0.0;
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
// load q tile into shared memory
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
let q_row = elem_idx / HEAD_DIM_QK;
let q_col = elem_idx % HEAD_DIM_QK;
let head_q_row = q_row_start + q_row;
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
q_shmem[elem_idx] = f16(select(
0.0,
Q[global_q_row_offset + q_col],
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
}
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
#ifdef BLK
let q_blk = q_row_start / Q_TILE;
let kv_blk = kv_tile / KV_TILE;
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
let blk_state_local = blk[blk_idx];
#else
let blk_state_local = 1u;
#endif
if (local_id.x == 0u) {
blk_state_wg = blk_state_local;
}
workgroupBarrier();
let blk_state = blk_state_wg;
let skip_tile = blk_state == 0u;
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = f16(0.0);
}
// load k tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
let k_row = elem_idx / HEAD_DIM_QK;
let k_col = elem_idx % HEAD_DIM_QK;
let global_k_row = kv_tile + k_row;
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
let vec_idx = (global_k_row_offset + k_col) >> 2u;
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f16(k4.x);
kv_shmem[elem_idx + 1u] = f16(k4.y);
kv_shmem[elem_idx + 2u] = f16(k4.z);
kv_shmem[elem_idx + 3u] = f16(k4.w);
}
#endif
workgroupBarrier();
// accumulate q block * k block into registers across the entire KV tile
if (!skip_tile) {
let num_of_threads = subgroup_size / VEC_NE;
let tx = sg_inv_id % num_of_threads;
let ty = sg_inv_id / num_of_threads;
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
continue;
}
let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
let kv_idx = kv_base + ty;
var partial_sum: f32 = 0.0;
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
if (kv_valid) {
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
let q_off = local_q_row_offset + i * 4u;
let qv = vec4<f32>(
f32(q_shmem[q_off + 0u]),
f32(q_shmem[q_off + 1u]),
f32(q_shmem[q_off + 2u]),
f32(q_shmem[q_off + 3u]));
#ifdef KV_DIRECT
let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
let kv = vec4<f32>(K[idx >> 2u]);
#else
let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
let kv = vec4<f32>(
f32(kv_shmem[idx + 0u]),
f32(kv_shmem[idx + 1u]),
f32(kv_shmem[idx + 2u]),
f32(kv_shmem[idx + 3u]));
#endif
partial_sum += dot(qv, kv);
}
}
var sum = partial_sum;
// Reduce over tx threads (NL) for this ty stripe.
var tx_delta = num_of_threads >> 1u;
loop {
if (tx_delta == 0u) {
break;
}
let sh = subgroupShuffleDown(sum, tx_delta);
if (tx < tx_delta) {
sum += sh;
}
tx_delta >>= 1u;
}
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
if (tx == 0u && kv_valid) {
let dst_idx = q_tile_row * KV_TILE + kv_idx;
inter_shmem[dst_idx] = f16(sum_bcast);
}
}
}
}
#ifdef MASK
let apply_mask = !skip_tile && (blk_state != 2u);
if (apply_mask) {
// load mask tile into shared memory for this KV block
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
let mask_row = elem_idx / KV_TILE;
let mask_col = elem_idx % KV_TILE;
let global_q_row = q_row_start + mask_row;
let global_k_col = kv_tile + mask_col;
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
}
}
#else
let apply_mask = false;
#endif
workgroupBarrier();
// online softmax
if (!skip_tile) {
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
var prev_max = row_max_shmem[q_tile_row];
var final_max = prev_max;
// pass 1: compute final max across the full KV tile in chunks
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
let kv_idx = kv_offset + sg_inv_id;
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
let softmax_term = select(FLOAT_MIN,
calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
kv_valid);
final_max = subgroupMax(max(final_max, softmax_term));
}
var total_exp_term: f32 = 0.0;
// pass 2: compute exp sum and write P using final_max
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
let kv_idx = kv_offset + sg_inv_id;
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
let cur_p = select(0.0,
exp(softmax_term - final_max),
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
total_exp_term += subgroupAdd(cur_p);
if (kv_idx < KV_TILE) {
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
}
}
let cur_exp = exp(prev_max - final_max);
if (sg_inv_id == 0) {
row_max_shmem[q_tile_row] = final_max;
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
}
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
}
}
}
// load v tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
let v_row = elem_idx / HEAD_DIM_V;
let v_col = elem_idx % HEAD_DIM_V;
let global_v_row = kv_tile + v_row;
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
let vec_idx = (global_v_row_offset + v_col) >> 2u;
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f16(v4.x);
kv_shmem[elem_idx + 1u] = f16(v4.y);
kv_shmem[elem_idx + 2u] = f16(v4.z);
kv_shmem[elem_idx + 3u] = f16(v4.w);
}
#endif
workgroupBarrier();
if (!skip_tile) {
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
// we want to compute O += P * V across the full KV tile
let ne_threads : u32 = VEC_NE;
let nl_threads = max(1u, subgroup_size / ne_threads);
let tx_pv = sg_inv_id % nl_threads;
let ty_pv = sg_inv_id / nl_threads;
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
let kv_idx = cc * ne_threads + ty_pv;
let v_row = kv_tile + kv_idx;
if (v_row >= params.seq_len_kv) {
continue;
}
let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
#ifdef KV_DIRECT
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
let v4 = vec4<f32>(V[v_idx >> 2u]);
#else
let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
let v4 = vec4<f32>(
f32(kv_shmem[v_idx + 0u]),
f32(kv_shmem[v_idx + 1u]),
f32(kv_shmem[v_idx + 2u]),
f32(kv_shmem[v_idx + 3u]));
#endif
lo += p * v4;
}
var lo_x = lo.x;
var lo_y = lo.y;
var lo_z = lo.z;
var lo_w = lo.w;
// Reduce over ty threads (NE) for this tx thread.
var ty_delta = ne_threads >> 1u;
loop {
if (ty_delta == 0u) {
break;
}
let thread_delta = ty_delta * nl_threads;
let shx = subgroupShuffleDown(lo_x, thread_delta);
let shy = subgroupShuffleDown(lo_y, thread_delta);
let shz = subgroupShuffleDown(lo_z, thread_delta);
let shw = subgroupShuffleDown(lo_w, thread_delta);
if (ty_pv < ty_delta) {
lo_x += shx;
lo_y += shy;
lo_z += shz;
lo_w += shw;
}
ty_delta >>= 1u;
}
if (ty_pv == 0u) {
let elem_base = vec_col * 4u;
let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
}
}
}
}
workgroupBarrier();
}
#ifdef SINKS
// Sinks are global terms and must be applied exactly once across split workgroups.
if (iwg == 0u) {
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
var prev_max = row_max_shmem[q_tile_row];
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
let new_max = subgroupMax(max(prev_max, sink_val));
let max_exp = exp(prev_max - new_max);
let sink_exp = exp(sink_val - new_max);
let sink_exp_sum = subgroupAdd(sink_exp);
if (sg_inv_id == 0) {
row_max_shmem[q_tile_row] = new_max;
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
}
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
}
}
workgroupBarrier();
}
#endif
let rows_per_batch = params.n_heads * params.seq_len_q;
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) { break; }
if (params.nwg == 1u) {
let exp_sum = exp_sum_shmem[q_tile_row];
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
let row_base: u32 =
params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
let v = vec4<f32>(
f32(o_shmem[i0]) * scale,
f32(o_shmem[i1]) * scale,
f32(o_shmem[i2]) * scale,
f32(o_shmem[i3]) * scale
);
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = v;
}
} else {
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
for (var elem_base = sg_inv_id * 4u;
elem_base < HEAD_DIM_V;
elem_base += subgroup_size * 4u) {
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
let tbase = tmp_row_data_base + elem_base;
tmp[tbase + 0u] = f32(o_shmem[i0]);
tmp[tbase + 1u] = f32(o_shmem[i1]);
tmp[tbase + 2u] = f32(o_shmem[i2]);
tmp[tbase + 3u] = f32(o_shmem[i3]);
}
if (sg_inv_id == 0u) {
tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
}
}
}
}
+1 -1
View File
@@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}
+179
View File
@@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat(
}
}
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
static void ggml_zendnn_compute_forward_mul_mat_id(
ggml_backend_zendnn_context * ctx,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // expert weights
const ggml_tensor * src1 = dst->src[1]; // inputs
const ggml_tensor * ids = dst->src[2]; // expert ids
GGML_TENSOR_BINARY_OP_LOCALS
// exit for no tokens to process
if (ne2 == 0 || ne11 == 0) {
return;
}
ggml_type const vec_dot_type = src0->type;
ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne3 == 1);
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_experts
std::vector<int64_t> matrix_row_counts(n_as, 0);
std::vector<std::vector<mmid_row_mapping>> matrix_rows(n_as);
int64_t max_rows = 0;
// group rows by expert (preprocessing step)
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
matrix_rows[i02].push_back({id, iid1});
matrix_row_counts[i02]++;
if (matrix_row_counts[i02] > max_rows) {
max_rows = matrix_row_counts[i02];
}
}
}
if (max_rows == 0) {
return; // no rows to process
}
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
// size for converting src1 rows to vec_dot_type if needed
const size_t nbw1 = row_size;
const size_t nbw2 = nbw1 * ne11;
const size_t nbw3 = nbw2 * ne12;
const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
// size for MoE gather/scatter buffers
const size_t wdata_cur_size = max_rows * row_size;
const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);
// allocate single buffer for all needs
const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size;
if (ctx->work_size < total_size) {
ctx->work_data.reset(new char[total_size]);
ctx->work_size = total_size;
}
// partition the buffer
char * work_data = ctx->work_data.get();
char * wdata_cur = work_data + src1_conv_size;
char * dst_cur = wdata_cur + wdata_cur_size;
if (src1->type != vec_dot_type) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
from_float(src1_f32, src1_conv, ne10);
}
}
}
}
const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
// process each expert with gather -> gemm -> scatter pattern
for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) {
continue;
}
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
// gather input rows for this expert
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
const int64_t id = row_mapping.i1;
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2;
std::memcpy(
wdata_cur + ir1 * row_size,
(const char *) wdata + (i11 + i12*ne11) * row_size,
row_size
);
}
// batched gemm for all tokens in this expert
if (!ggml_zendnn_sgemm(ctx,
ne01, // m
cne1, // n
ne10, // k
src0_cur,
ne00, // lda
wdata_cur,
ne10, // ldb
dst_cur,
ne01, // ldc
src0->type,
vec_dot_type,
dst->type)) {
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
}
// scatter output rows to destination
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
const int64_t id = row_mapping.i1;
const int64_t i1 = id;
const int64_t i2 = row_mapping.i2;
std::memcpy(
(char *) dst->data + i1*nb1 + i2*nb2,
dst_cur + ir1 * ggml_row_size(dst->type, ne01),
ggml_row_size(dst->type, ne01)
);
}
}
}
// backend interface
static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
@@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
case GGML_OP_MUL_MAT:
ggml_zendnn_compute_forward_mul_mat(ctx, node);
break;
case GGML_OP_MUL_MAT_ID:
ggml_zendnn_compute_forward_mul_mat_id(ctx, node);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
@@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
const ggml_tensor * weights = op->src[0];
const ggml_tensor * inputs = op->src[1];
@@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
return false;
}
// MUL_MAT_ID performs best with a moderate number of experts due to its
// gather + batched matmul + scatter approach. Future versions will leverage
// ZenDNN's grouped_gemm for better scalability with larger expert counts:
// https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md
if (op->op == GGML_OP_MUL_MAT_ID) {
const int64_t n_experts = weights->ne[2];
const int64_t max_experts = 32;
if (n_experts > max_experts) {
return false;
}
}
switch (weights->type) {
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
+266
View File
@@ -0,0 +1,266 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{ bos_token }}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{{- '<|turn>' + role + '\n' }}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- if message['tool_responses'] -%}
{#- Tool Response handling -#}
{%- for tool_response in message['tool_responses'] -%}
{{- '<|tool_response>' -}}
{%- if tool_response['response'] is mapping -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
{%- for key, value in tool_response['response'] | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '\n\n<|image|>\n\n' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '\n\n<|video|>\n\n' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if not (message['tool_responses'] and not message['content']) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- if not enable_thinking | default(false) -%}
{{- '<|channel>thought\n<channel|>' -}}
{%- endif -%}
{%- endif -%}
+47 -3
View File
@@ -1,8 +1,8 @@
#pragma once
#include "llama-context.h"
#include "ggml.h"
#include "stdint.h"
#include "llama.h"
#include <cstdint>
// Reserve a new compute graph. It is valid until the next call to llama_graph_reserve.
LLAMA_API struct ggml_cgraph * llama_graph_reserve(
@@ -10,3 +10,47 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve(
uint32_t n_tokens,
uint32_t n_seqs,
uint32_t n_outputs);
// Get the default ggml_type for a given ftype.
LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype);
// Quantization state.
struct quantize_state_impl;
LLAMA_API quantize_state_impl * llama_quant_init(
const llama_model * model,
const llama_model_quantize_params * params);
LLAMA_API void llama_quant_free(quantize_state_impl * qs);
// Descriptor for constructing a mock model for quantization testing.
struct llama_quant_model_desc {
const char * architecture;
uint32_t n_embd;
uint32_t n_ff;
uint32_t n_layer;
uint32_t n_head;
uint32_t n_head_kv;
uint32_t n_expert;
uint32_t n_embd_head_k;
uint32_t n_embd_head_v;
};
// Create a mock model from a metadata descriptor (for testing).
// The returned model must be freed with llama_model_free().
LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc);
// Returns true if this tensor should be quantized (based on name, dims, params).
LLAMA_API bool llama_quant_tensor_allows_quantization(
const quantize_state_impl * qs,
const ggml_tensor * tensor);
// Compute quantization type assignments for a list of tensors.
// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter).
// result_types: caller-allocated array of n_tensors elements, filled with assigned types.
LLAMA_API void llama_quant_compute_types(
quantize_state_impl * qs,
llama_ftype ftype,
ggml_tensor ** tensors,
ggml_type * result_types,
size_t n_tensors);
+1 -2
View File
@@ -66,9 +66,8 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
// note: the SWA cache is never quantized because it is relatively small
kv_swa = std::make_unique<llama_kv_cache>(
model, GGML_TYPE_F16, GGML_TYPE_F16,
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
+125 -31
View File
@@ -1,11 +1,11 @@
#include "llama.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"
#include "llama-ext.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <cinttypes>
#include <fstream>
#include <mutex>
@@ -197,6 +197,7 @@ struct quantize_state_impl {
// per-tensor metadata, computed in the preliminary loop and used in the main loop
struct tensor_metadata {
std::string name;
ggml_type target_type;
tensor_category category;
std::string remapped_imatrix_name;
@@ -788,7 +789,7 @@ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type ds
// given a file type, get the default tensor type
//
static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
switch (ftype) {
case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0;
case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1;
@@ -827,16 +828,32 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ3_S:
case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S;
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
default: return GGML_TYPE_COUNT;
}
}
static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<tensor_metadata> & metadata) {
for (auto & tm : metadata) {
tensor_category cat = tensor_get_category(tm.name);
tm.category = cat;
if (category_is_attn_v(cat)) {
++qs.n_attention_wv;
}
if (cat == tensor_category::OUTPUT) {
qs.has_tied_embeddings = false;
}
}
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
}
//
// main quantization driver
//
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
ggml_type default_type;
llama_ftype ftype = params->ftype;
int nthread = params->nthread;
@@ -845,7 +862,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
nthread = std::thread::hardware_concurrency();
}
default_type = llama_ftype_get_default_type(ftype);
ggml_type default_type = llama_ftype_get_default_type(ftype);
if (default_type == GGML_TYPE_COUNT) {
throw std::runtime_error(format("invalid output file type %d\n", ftype));
}
// mmap consistently increases speed on Linux, and also increases speed on Windows with
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
@@ -964,6 +984,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
});
}
// compute tensor metadata once and cache it
std::vector<tensor_metadata> metadata(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
metadata[i].name = ggml_get_name(tensors[i]->tensor);
}
// initialize quantization state counters and metadata categories
init_quantize_state_counters(qs, metadata);
int idx = 0;
uint16_t n_split = 1;
@@ -976,25 +1005,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
std::vector<gguf_context_ptr> ctx_outs(n_split);
ctx_outs[0] = std::move(ctx_out);
// compute tensor metadata once and cache it
std::vector<tensor_metadata> metadata(tensors.size());
// initialize quantization state before preliminary loop (counters for use_more_bits)
{
for (size_t i = 0; i < tensors.size(); ++i) {
const auto cat = tensor_get_category(tensors[i]->tensor->name);
if (category_is_attn_v(cat)) {
++qs.n_attention_wv;
}
if (cat == tensor_category::OUTPUT) {
qs.has_tied_embeddings = false;
}
metadata[i].category = cat; // save and re-use the category while we're at it
}
// these also need to be set to n_layer by default
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
}
// flag for --dry-run
bool will_require_imatrix = false;
@@ -1005,7 +1015,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
for (size_t i = 0; i < tensors.size(); ++i) {
const auto * it = tensors[i];
const struct ggml_tensor * tensor = it->tensor;
const std::string name = ggml_get_name(tensor);
uint16_t i_split = params->keep_split ? it->idx : 0;
if (!ctx_outs[i_split]) {
@@ -1034,7 +1043,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
" - offending tensor: %s\n"
" - target type: %s\n"
"============================================================================\n\n",
name.c_str(), ggml_type_name(metadata[i].target_type));
metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type));
throw std::runtime_error("this quantization requires an imatrix!");
}
}
@@ -1107,7 +1116,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
new_ofstream(weight.idx);
}
const std::string name = ggml_get_name(tensor);
const size_t tensor_size = ggml_nbytes(tensor);
if (!params->dry_run) {
@@ -1238,9 +1246,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
total_size_new += new_size;
// update the gguf meta data as we go
gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type);
GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size);
gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data);
// write tensor data + padding
fout.write((const char *) new_data, new_size);
@@ -1305,3 +1313,89 @@ uint32_t llama_model_quantize(
return 0;
}
//
// Helper functions for external tools exposed in llama-ext.h
//
quantize_state_impl * llama_quant_init(
const llama_model * model,
const llama_model_quantize_params * params) {
return new quantize_state_impl(*model, params);
}
void llama_quant_free(quantize_state_impl * qs) {
delete qs;
}
llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) {
struct llama_model_params mparams = llama_model_default_params();
auto * model = new llama_model(mparams);
model->arch = llm_arch_from_string(desc->architecture);
// infer llm_type: only LLM_TYPE_70B matters for quantization logic
if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) {
model->type = LLM_TYPE_70B;
}
model->hparams.n_embd = desc->n_embd;
model->hparams.n_embd_head_k_full = desc->n_embd_head_k;
model->hparams.n_embd_head_v_full = desc->n_embd_head_v;
model->hparams.n_layer = desc->n_layer;
model->hparams.n_expert = desc->n_expert;
for (uint32_t i = 0; i < desc->n_layer; i++) {
model->hparams.n_head_arr[i] = desc->n_head;
model->hparams.n_head_kv_arr[i] = desc->n_head_kv;
model->hparams.n_ff_arr[i] = desc->n_ff;
}
return model;
}
bool llama_quant_tensor_allows_quantization(
const quantize_state_impl * qs,
const ggml_tensor * tensor) {
return tensor_allows_quantization(qs->params, qs->model.arch, tensor);
}
void llama_quant_compute_types(
quantize_state_impl * qs,
llama_ftype ftype,
ggml_tensor ** tensors,
ggml_type * result_types,
size_t n_tensors) {
// reset per-computation state
qs->n_attention_wv = 0;
qs->n_ffn_down = 0;
qs->n_ffn_gate = 0;
qs->n_ffn_up = 0;
qs->i_attention_wv = 0;
qs->i_ffn_down = 0;
qs->i_ffn_gate = 0;
qs->i_ffn_up = 0;
qs->n_fallback = 0;
qs->has_imatrix = false;
qs->has_tied_embeddings = true;
// build metadata from tensor names
std::vector<tensor_metadata> metadata(n_tensors);
for (size_t i = 0; i < n_tensors; i++) {
metadata[i].name = ggml_get_name(tensors[i]);
}
// initialize counters and categories
init_quantize_state_counters(*qs, metadata);
// use a local copy of params with the requested ftype
llama_model_quantize_params local_params = *qs->params;
local_params.ftype = ftype;
ggml_type default_type = llama_ftype_get_default_type(ftype);
// compute types
for (size_t i = 0; i < n_tensors; i++) {
result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]);
}
}
+61 -5
View File
@@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_GEMMA4:
// Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the
// normalizer, then BPE merges run on the whole text without
// word-level pre-splitting. We only need to split on newlines
// since BPE merge lookup asserts no newlines in tokens.
regex_exprs = {
"[^\\n]+|[\\n]+",
};
byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
}
std::vector<std::string> regex_exprs;
bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8)
};
struct llm_tokenizer_bpe_session {
@@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session {
void tokenize(const std::string & text, std::vector<llama_token> & output) {
int final_prev_index = -1;
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode);
symbols_final.clear();
auto tok_pre = vocab.get_pre_type();
for (const auto & word : word_collection) {
work_queue = llm_bigram_bpe::queue();
@@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session {
if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
offset = word.size();
} else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) {
// fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343
auto tok = vocab.text_to_token(word);
if (tok != LLAMA_TOKEN_NULL) {
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
offset = word.size();
}
}
while (offset < word.size()) {
@@ -1864,7 +1883,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_pad_id = 3; // <|plamo:pad|>
special_mask_id = LLAMA_TOKEN_NULL;
} else if (tokenizer_model == "gemma4") {
type = LLAMA_VOCAB_TYPE_SPM;
type = LLAMA_VOCAB_TYPE_BPE;
// read bpe merges and populate bpe ranks
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
if (merges_keyidx == -1) {
throw std::runtime_error("cannot find tokenizer merges in model file\n");
}
{
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
for (int i = 0; i < n_merges; i++) {
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
std::string first;
std::string second;
const size_t pos = word.find(' ', 1);
if (pos != std::string::npos) {
first = word.substr(0, pos);
second = word.substr(pos + 1);
}
bpe_ranks.emplace(std::make_pair(first, second), i);
}
}
// default special tokens (to be read from GGUF)
special_bos_id = LLAMA_TOKEN_NULL;
@@ -1874,7 +1917,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_pad_id = LLAMA_TOKEN_NULL;
special_mask_id = LLAMA_TOKEN_NULL;
tokenizer_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
tokenizer_pre = "gemma4";
} else {
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
}
@@ -1882,6 +1925,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// for now, only BPE models have pre-tokenizers
if (type == LLAMA_VOCAB_TYPE_BPE) {
add_space_prefix = false;
escape_whitespaces = false;
clean_spaces = true;
if (tokenizer_pre.empty()) {
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
@@ -1948,6 +1992,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "jais-2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
} else if (
tokenizer_pre == "gemma4") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4;
escape_whitespaces = true;
} else if (
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-code" ||
@@ -3045,6 +3093,10 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (escape_whitespaces) {
llama_escape_whitespace(text);
}
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
#endif
@@ -3224,6 +3276,12 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
return _try_copy(token_text.data(), token_text.size());
}
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
if (escape_whitespaces) {
// SPM-style BPE: tokens contain ▁ for spaces
std::string result = token_text;
llama_unescape_whitespace(result);
return _try_copy(result.data(), result.size());
}
std::string result = llama_decode_text(token_text);
return _try_copy(result.data(), result.size());
}
@@ -3654,9 +3712,7 @@ int llama_vocab::max_token_len() const {
int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos);
GGML_ASSERT(token_left.find('\n') == std::string::npos);
GGML_ASSERT(token_right.find(' ') == std::string::npos);
GGML_ASSERT(token_right.find('\n') == std::string::npos);
auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
if (it == pimpl->bpe_ranks.end()) {
+1
View File
@@ -58,6 +58,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50,
};
struct LLM_KV;
+6 -2
View File
@@ -912,7 +912,7 @@ bool unicode_cpt_is_han(uint32_t cpt) {
return false;
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", unicode_cpt_flags::NUMBER },
@@ -1099,5 +1099,9 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
start += offset;
}
return unicode_byte_encoding_process(bpe_words);
if (byte_encode) {
return unicode_byte_encoding_process(bpe_words);
}
return bpe_words;
}
+1 -1
View File
@@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt);
bool unicode_cpt_is_han(uint32_t cpt);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode = true);
+1
View File
@@ -1,5 +1,6 @@
*
!*.*
!snapshots/
*.o
ggml-common.h
**/*.swp
+10
View File
@@ -274,6 +274,12 @@ if (TARGET cpp-httplib)
add_executable(test-gguf-model-data test-gguf-model-data.cpp)
target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common)
llama_test(test-gguf-model-data LABEL "model")
# test-quant-type-selection requires gguf-model-data for remote model metadata
llama_build_and_test(test-quant-type-selection.cpp LABEL "model")
target_link_libraries(test-quant-type-selection PRIVATE gguf-model-data)
target_compile_definitions(test-quant-type-selection PRIVATE
SNAPSHOT_DIR="${CMAKE_CURRENT_SOURCE_DIR}/snapshots")
endif()
endif()
@@ -287,3 +293,7 @@ target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
llama_build(export-graph-ops.cpp)
target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
if (TARGET gguf-model-data)
target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
endif()
+63 -6
View File
@@ -1,15 +1,26 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "llama-cpp.h"
#include "../src/llama-ext.h"
#include "ggml.h"
#include "gguf-model-data.h"
#include "gguf.h"
#include "ggml-backend.h"
#include "download.h"
#include <array>
#include <vector>
#include <set>
#include <fstream>
#include <iostream>
#include <random>
// Noop because weights are not needed
static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) {
GGML_UNUSED(tensor);
GGML_UNUSED(userdata);
}
struct input_tensor {
ggml_type type;
@@ -132,9 +143,52 @@ int main(int argc, char ** argv) {
params.warmup = false;
auto init_result = common_init_from_params(params);
llama_context * ctx;
common_init_result_ptr init_result;
llama_context_ptr ctx2;
llama_model_ptr model;
llama_context * ctx = init_result->context();
if (params.model.hf_repo.empty()) {
init_result = common_init_from_params(params);
ctx = init_result->context();
} else {
#ifdef LLAMA_HF_FETCH
auto [hf_repo, hf_quant] = common_download_split_repo_tag(params.model.hf_repo);
if (hf_quant.empty() || hf_quant == "latest") {
hf_quant = "Q4_K_M";
}
gguf_context_ptr gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant);
if (!gguf_ctx) {
LOG_ERR("failed to fetch GGUF metadata from %s\n", hf_repo.c_str());
return 1;
}
llama_model_params model_params = llama_model_default_params();
model_params.devices = params.devices.data();
model_params.no_alloc = true;
model.reset(llama_model_init_from_user(gguf_ctx.get(), set_tensor_data, nullptr, model_params));
if (!model) {
LOG_ERR("failed to create llama_model from %s\n", hf_repo.c_str());
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx2.reset(llama_init_from_model(model.get(), ctx_params));
ctx = ctx2.get();
if (!ctx) {
LOG_ERR("failed to create llama_context\n");
return 1;
}
#else
LOG_ERR("export-graph-ops compiled without HF fetch support\n");
return 1;
#endif
}
const uint32_t n_seqs = llama_n_seq_max(ctx);
const uint32_t n_tokens = std::min(llama_n_ctx(ctx), llama_n_ubatch(ctx));
@@ -143,13 +197,15 @@ int main(int argc, char ** argv) {
auto * gf_pp = llama_graph_reserve(ctx, n_tokens, n_seqs, n_tokens);
if (!gf_pp) {
throw std::runtime_error("failed to reserve prompt processing graph");
LOG_ERR("failed to reserve prompt processing graph\n");
return 1;
}
extract_graph_ops(gf_pp, "pp", tests);
auto * gf_tg = llama_graph_reserve(ctx, n_seqs, n_seqs, n_seqs);
if (!gf_tg) {
throw std::runtime_error("failed to reserve token generation graph");
LOG_ERR("failed to reserve token generation graph\n");
return 1;
}
extract_graph_ops(gf_tg, "tg", tests);
@@ -158,7 +214,8 @@ int main(int argc, char ** argv) {
std::ofstream f(params.out_file);
if (!f.is_open()) {
throw std::runtime_error("Unable to open output file");
LOG_ERR("unable to open output file: %s\n", params.out_file.c_str());
return 1;
}
for (const auto& test : tests) {
+139 -12
View File
@@ -4,6 +4,7 @@
#include "gguf-model-data.h"
#include "common.h"
#include "ggml-cpp.h"
#include "gguf.h"
#include <algorithm>
@@ -124,6 +125,35 @@ static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
}
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
// Handle array-valued fields (e.g. per-layer head counts in hybrid models)
// by reading the first element as a representative value.
if (vtype == GGUF_TYPE_ARRAY) {
int32_t elem_type;
uint64_t count;
if (!r.read_val(elem_type)) {
return false;
}
if (!r.read_val(count)) {
return false;
}
if (count == 0) {
return false;
}
// Read first element, skip the rest
if (!gguf_read_uint32_val(r, elem_type, out)) {
return false;
}
for (uint64_t i = 1; i < count; i++) {
size_t sz = gguf_val_type_size(elem_type);
if (sz == 0) {
return false;
}
if (!r.skip(sz)) {
return false;
}
}
return true;
}
if (vtype == GGUF_TYPE_UINT8) {
uint8_t v;
if (!r.read_val(v)) {
@@ -486,7 +516,8 @@ static std::string detect_gguf_filename(const std::string & repo, const std::str
static std::optional<gguf_remote_model> fetch_and_parse(
const std::string & repo,
const std::string & filename,
const std::string & cache_path) {
const std::string & cache_path,
bool verbose) {
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
// Progressive download inspired by RangeView.fetchChunk()
@@ -495,7 +526,9 @@ static std::optional<gguf_remote_model> fetch_and_parse(
const size_t max_chunk = 64 * 1024 * 1024;
while (chunk_size <= max_chunk) {
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
if (verbose) {
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
}
char range_buf[64];
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
@@ -531,34 +564,42 @@ static std::optional<gguf_remote_model> fetch_and_parse(
return std::nullopt;
}
static std::string get_cache_file_path(const std::string& cdir, const std::string& repo_part, const std::string& filename) {
std::string fname_part = sanitize_for_path(filename);
return cdir + "/" + repo_part + "--" + fname_part + ".partial";
}
// Try cache first, then fetch and parse a single GGUF shard.
static std::optional<gguf_remote_model> fetch_or_cached(
const std::string & repo,
const std::string & filename,
const std::string & cdir,
const std::string & repo_part) {
std::string fname_part = sanitize_for_path(filename);
std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
const std::string & repo_part,
bool verbose) {
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
{
std::vector<char> cached;
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
auto result = gguf_parse_meta(cached);
if (result.has_value()) {
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
if (verbose) {
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
}
return result;
}
}
}
fs_create_directory_with_parents(cdir);
return fetch_and_parse(repo, filename, cache_path);
return fetch_and_parse(repo, filename, cache_path, verbose);
}
std::optional<gguf_remote_model> gguf_fetch_model_meta(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir) {
const std::string & cache_dir,
bool verbose) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
@@ -568,7 +609,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return std::nullopt;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return std::nullopt;
@@ -583,8 +624,10 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return std::nullopt;
}
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
if (verbose) {
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
}
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
@@ -592,7 +635,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
return std::nullopt;
@@ -611,3 +654,87 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return model_opt;
}
gguf_context_ptr gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir,
bool verbose) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
std::string split_prefix;
std::string filename = detect_gguf_filename(repo, quant, split_prefix);
if (filename.empty()) {
return nullptr;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return nullptr;
}
auto & model = model_opt.value();
const std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
ggml_context_ptr ggml_ctx_ptr;
ggml_context * ggml_ctx{};
gguf_init_params params{true, &ggml_ctx};
gguf_context_ptr ctx{gguf_init_from_file(cache_path.c_str(), params)};
ggml_ctx_ptr.reset(ggml_ctx);
if (ctx == nullptr) {
fprintf(stderr, "gguf_fetch: gguf_init_from_file failed\n");
return nullptr;
}
// If the model is split across multiple files we need to fetch the remaining shards metadata
if (model.n_split > 1) {
if (split_prefix.empty()) {
fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
return nullptr;
}
if (verbose) {
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
}
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
snprintf(num_buf, sizeof(num_buf), "%05d", i);
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
return nullptr;
}
// Load tensors from shard and add to main gguf_context
const std::string shard_path = get_cache_file_path(cdir, repo_part, shard_name);
ggml_context_ptr shard_ggml_ctx_ptr;
ggml_context * shard_ggml_ctx{};
gguf_init_params shard_params{true, &shard_ggml_ctx};
gguf_context_ptr shard_ctx{gguf_init_from_file(shard_path.c_str(), shard_params)};
shard_ggml_ctx_ptr.reset(shard_ggml_ctx);
if (shard_ctx == nullptr) {
fprintf(stderr, "gguf_fetch: shard gguf_init_from_file failed\n");
return nullptr;
}
for (ggml_tensor * t = ggml_get_first_tensor(shard_ggml_ctx); t; t = ggml_get_next_tensor(shard_ggml_ctx, t)) {
gguf_add_tensor(ctx.get(), t);
}
}
gguf_set_val_u16(ctx.get(), "split.count", 1);
}
return ctx;
}
+10 -2
View File
@@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "ggml-cpp.h"
#include "gguf.h"
#include <cstdint>
#include <optional>
@@ -39,4 +40,11 @@ struct gguf_remote_model {
std::optional<gguf_remote_model> gguf_fetch_model_meta(
const std::string & repo,
const std::string & quant = "Q8_0",
const std::string & cache_dir = ""); // empty = default
const std::string & cache_dir = "", // empty = default
bool verbose = true);
gguf_context_ptr gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant = "Q8_0",
const std::string & cache_dir = "",
bool verbose = true);
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+282
View File
@@ -589,6 +589,51 @@ static common_chat_tool amount_tool{
})",
};
static common_chat_tool toggle_tool{
/* .name = */ "toggle",
/* .description = */ "Toggle a feature",
/* .parameters = */ R"({
"type": "object",
"properties": {
"enabled": {
"type": "boolean",
"description": "Whether to enable the feature"
}
},
"required": ["enabled"]
})",
};
static common_chat_tool nullable_tool{
/* .name = */ "set_nullable",
/* .description = */ "Set a nullable value",
/* .parameters = */ R"({
"type": "object",
"properties": {
"value": {
"type": "null",
"description": "A null value"
}
},
"required": ["value"]
})",
};
static common_chat_tool config_tool{
/* .name = */ "set_config",
/* .description = */ "Set configuration",
/* .parameters = */ R"({
"type": "object",
"properties": {
"config": {
"type": "object",
"description": "Configuration dict"
}
},
"required": ["config"]
})",
};
static common_chat_tool imaginary_number_tool{
/* .name = */ "imaginary_number",
/* .description = */ "Imaginary number converter",
@@ -612,6 +657,66 @@ static common_chat_tool imaginary_number_tool{
})",
};
static common_chat_tool nullable_string_tool{
/* .name = */ "set_nullable_str",
/* .description = */ "Set a nullable string value",
/* .parameters = */ R"({
"type": "object",
"properties": {
"name": {
"type": ["string", "null"],
"description": "A nullable string"
}
},
"required": ["name"]
})",
};
static common_chat_tool nullable_string_null_first_tool{
/* .name = */ "set_nullable_str_nf",
/* .description = */ "Set a nullable string value with null first in type array",
/* .parameters = */ R"({
"type": "object",
"properties": {
"name": {
"type": ["null", "string"],
"description": "A nullable string with null first"
}
},
"required": ["name"]
})",
};
static common_chat_tool nullable_int_tool{
/* .name = */ "set_nullable_int",
/* .description = */ "Set a nullable integer value",
/* .parameters = */ R"({
"type": "object",
"properties": {
"count": {
"type": ["integer", "null"],
"description": "A nullable integer"
}
},
"required": ["count"]
})",
};
static common_chat_tool enum_no_type_tool{
/* .name = */ "set_unit",
/* .description = */ "Set a temperature unit",
/* .parameters = */ R"({
"type": "object",
"properties": {
"unit": {
"enum": ["celsius", "fahrenheit"],
"description": "Temperature unit"
}
},
"required": ["unit"]
})",
};
static common_chat_tool string_param_tool{
/* .name = */ "string_param",
/* .description = */ "Tool with string parameter for testing",
@@ -1869,6 +1974,130 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run();
}
{
// Google Gemma 4 (tool calling with Gemma4 dict format)
auto tst = peg_tester("models/templates/gemma4.jinja");
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
// Simple tool call with string argument
tst.test(
"<|tool_call>call:get_time{city:<|\"|>London<|\"|>}<tool_call|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", R"({"city": "London"})"))
.run();
// Tool call with string argument containing special chars
tst.test(
"<|tool_call>call:get_time{city:<|\"|>San Francisco<|\"|>}<tool_call|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", R"({"city": "San Francisco"})"))
.run();
// Tool call with empty args
tst.test(
"<|tool_call>call:empty_args{}<tool_call|>")
.tools({ empty_args_tool })
.expect(message_with_tool_calls("empty_args", "{}"))
.run();
// Tool call with string and content
tst.test(
"Hello, world!\nWhat's up?<|tool_call>call:get_time{city:<|\"|>Paris<|\"|>}<tool_call|>")
.tools({ get_time_tool })
.expect(message_with_content_and_tool_call("Hello, world!\nWhat's up?", "get_time", R"({"city": "Paris"})"))
.run();
// Parallel tool calls
tst.test(
"<|tool_call>call:get_time{city:<|\"|>London<|\"|>}<tool_call|>"
"<|tool_call>call:get_weather{city:<|\"|>Paris<|\"|>}<tool_call|>")
.tools({ get_time_tool, get_weather_tool })
.parallel_tool_calls(true)
.expect_tool_calls({
{ "get_time", R"({"city": "London"})", "" },
{ "get_weather", R"({"city": "Paris"})", "" },
})
.run();
// Tool call with integer argument (number type)
tst.test(
"<|tool_call>call:special_function{arg1:42}<tool_call|>")
.tools({ special_function_tool })
.expect(message_with_tool_calls("special_function", R"({"arg1": 42})"))
.run();
// Tool call with negative number argument
tst.test(
"<|tool_call>call:special_function{arg1:-7}<tool_call|>")
.tools({ special_function_tool })
.expect(message_with_tool_calls("special_function", R"({"arg1": -7})"))
.run();
// Tool call with decimal number argument
tst.test(
"<|tool_call>call:amount{orig:3.14}<tool_call|>")
.tools({ amount_tool })
.expect(message_with_tool_calls("amount", R"({"orig": 3.14})"))
.run();
// Tool call with boolean argument (true)
tst.test(
"<|tool_call>call:toggle{enabled:true}<tool_call|>")
.tools({ toggle_tool })
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
.run();
// Tool call with boolean argument (false)
tst.test(
"<|tool_call>call:toggle{enabled:false}<tool_call|>")
.tools({ toggle_tool })
.expect(message_with_tool_calls("toggle", R"({"enabled": false})"))
.run();
// Tool call with null argument
tst.test(
"<|tool_call>call:set_nullable{value:null}<tool_call|>")
.tools({ nullable_tool })
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
.run();
// Tool call with array argument (todo list)
tst.test(
"<|tool_call>call:todo_list{todos:[<|\"|>buy milk<|\"|>,<|\"|>walk dog<|\"|>]}<tool_call|>")
.tools({ todo_list })
.expect(message_with_tool_calls("todo_list", R"({"todos":["buy milk","walk dog"]})"))
.run();
// Tool call with object/dict argument
tst.test(
"<|tool_call>call:set_config{config:{theme:<|\"|>dark<|\"|>,count:3}}<tool_call|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config":{"theme":"dark","count":3}})"))
.run();
// Tool call with empty array
tst.test(
"<|tool_call>call:todo_list{todos:[]}<tool_call|>")
.tools({ todo_list })
.expect(message_with_tool_calls("todo_list", R"({"todos":[]})"))
.run();
// Tool call with empty dict
tst.test(
"<|tool_call>call:set_config{config:{}}<tool_call|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config":{}})"))
.run();
// Tool call with scientific notation number
tst.test(
"<|tool_call>call:amount{orig:1.5e10}<tool_call|>")
.tools({ amount_tool })
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
.run();
}
{
// Qwen-QwQ-32B (reasoning model)
auto tst = peg_tester("models/templates/Qwen-QwQ-32B.jinja");
@@ -2031,6 +2260,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
}
})
.run();
}
{
@@ -2214,6 +2444,58 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
})
.expect_reconstruction()
.run();
// nullable string type ["string", "null"]
tst.test(
"<tool_call>\n"
"<function=set_nullable_str>\n"
"<parameter=name>\nhello world\n</parameter>\n"
"</function>\n"
"</tool_call>")
.tools({ nullable_string_tool })
.expect_tool_calls({
{ "set_nullable_str", R"({"name": "hello world"})", {} },
})
.run();
// nullable string with null first in type array ["null", "string"]
tst.test(
"<tool_call>\n"
"<function=set_nullable_str_nf>\n"
"<parameter=name>\nhello world\n</parameter>\n"
"</function>\n"
"</tool_call>")
.tools({ nullable_string_null_first_tool })
.expect_tool_calls({
{ "set_nullable_str_nf", R"({"name": "hello world"})", {} },
})
.run();
// nullable integer type ["integer", "null"] - should use JSON value path, not string
tst.test(
"<tool_call>\n"
"<function=set_nullable_int>\n"
"<parameter=count>\n42\n</parameter>\n"
"</function>\n"
"</tool_call>")
.tools({ nullable_int_tool })
.expect_tool_calls({
{ "set_nullable_int", R"({"count": 42})", {} },
})
.run();
// enum without explicit type key - should infer string from enum values
tst.test(
"<tool_call>\n"
"<function=set_unit>\n"
"<parameter=unit>\ncelsius\n</parameter>\n"
"</function>\n"
"</tool_call>")
.tools({ enum_no_type_tool })
.expect_tool_calls({
{ "set_unit", R"({"unit": "celsius"})", {} },
})
.run();
}
{
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug);
+33
View File
@@ -116,6 +116,39 @@ int main() {
// Verify tensor count
TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780");
// Test a hybrid-attention model with array-valued head counts
auto result4 = gguf_fetch_model_meta("ggml-org/Step-3.5-Flash-GGUF", "Q4_K");
if (!result4.has_value()) {
fprintf(stderr, "FAIL: could not fetch Step-3.5-Flash metadata\n");
return 1;
}
const auto & model4 = result4.value();
fprintf(stderr, "Architecture: %s\n", model4.architecture.c_str());
fprintf(stderr, "n_embd: %u\n", model4.n_embd);
fprintf(stderr, "n_ff: %u\n", model4.n_ff);
fprintf(stderr, "n_vocab: %u\n", model4.n_vocab);
fprintf(stderr, "n_layer: %u\n", model4.n_layer);
fprintf(stderr, "n_head: %u\n", model4.n_head);
fprintf(stderr, "n_head_kv: %u\n", model4.n_head_kv);
fprintf(stderr, "n_expert: %u\n", model4.n_expert);
fprintf(stderr, "n_embd_head_k: %u\n", model4.n_embd_head_k);
fprintf(stderr, "n_embd_head_v: %u\n", model4.n_embd_head_v);
fprintf(stderr, "tensors: %zu\n", model4.tensors.size());
TEST_ASSERT(model4.architecture == "step35", "expected architecture 'step35'");
TEST_ASSERT(model4.n_layer == 45, "expected n_layer == 45");
TEST_ASSERT(model4.n_embd == 4096, "expected n_embd == 4096");
TEST_ASSERT(model4.n_ff == 11264, "expected n_ff == 11264");
TEST_ASSERT(model4.n_head == 64, "expected n_head == 64 (first element of per-layer array)");
TEST_ASSERT(model4.n_head_kv == 8, "expected n_head_kv == 8 (first element of per-layer array)");
TEST_ASSERT(model4.n_expert == 288, "expected n_expert == 288");
TEST_ASSERT(model4.n_embd_head_k == 128, "expected n_embd_head_k == 128");
TEST_ASSERT(model4.n_embd_head_v == 128, "expected n_embd_head_v == 128");
TEST_ASSERT(model4.n_vocab == 128896, "expected n_vocab == 128896");
TEST_ASSERT(model4.tensors.size() == 754, "expected tensor count == 754");
fprintf(stderr, "=== ALL TESTS PASSED ===\n");
return 0;
}
+12
View File
@@ -523,6 +523,18 @@ static void test_filters(testing & t) {
"hello"
);
test_template(t, "upper array",
"{{ items|upper }}",
{{"items", json::array({"hello", "world"})}},
"['HELLO', 'WORLD']"
);
test_template(t, "upper dict",
"{{ items|upper }}",
{{"items", {{"hello", "world"}}}},
"{'HELLO': 'WORLD'}"
);
test_template(t, "capitalize",
"{{ 'heLlo World'|capitalize }}",
json::object(),
+520
View File
@@ -0,0 +1,520 @@
#include "../src/llama-ext.h"
#include "ggml-cpp.h"
#include "gguf-model-data.h"
#include "llama.h"
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
// ---------------------------------------------------------------------------
// ftype name <-> enum mapping
// ---------------------------------------------------------------------------
struct ftype_name_entry {
const char * name;
llama_ftype ftype;
};
static const ftype_name_entry ftype_name_table[] = {
{ "F32", LLAMA_FTYPE_ALL_F32 },
{ "F16", LLAMA_FTYPE_MOSTLY_F16 },
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16 },
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0 },
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1 },
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0 },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1 },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0 },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S },
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S },
{ "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M },
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M },
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S },
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M },
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS },
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS },
{ "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S },
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS },
{ "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS },
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S },
{ "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS },
{ "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0 },
{ "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0 },
{ "MXFP4_MOE", LLAMA_FTYPE_MOSTLY_MXFP4_MOE },
{ "NVFP4", LLAMA_FTYPE_MOSTLY_NVFP4 },
};
static llama_ftype llama_ftype_from_name(const char * name) {
for (const auto & e : ftype_name_table) {
if (strcmp(name, e.name) == 0) {
return e.ftype;
}
}
return (llama_ftype) -1;
}
static const char * llama_ftype_to_name(llama_ftype ftype) {
for (const auto & e : ftype_name_table) {
if (e.ftype == ftype) {
return e.name;
}
}
return nullptr;
}
// ---------------------------------------------------------------------------
// ggml_type name lookup
// ---------------------------------------------------------------------------
static ggml_type ggml_type_from_name(const std::string & name) {
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
const char * tname = ggml_type_name((ggml_type) i);
if (tname && name == tname) {
return (ggml_type) i;
}
}
return GGML_TYPE_COUNT;
}
// ---------------------------------------------------------------------------
// File parser for snapshot files (quant type schemas)
// ---------------------------------------------------------------------------
struct snapshot_section {
llama_ftype ftype;
ggml_type default_type;
std::vector<std::pair<std::string, ggml_type>> overrides;
};
// This function is pretty ugly, but it's a trade-off of readable snapshot files
// versus readable parsing code
static bool parse_snapshot_file(const std::string & path, std::vector<snapshot_section> & sections) {
std::ifstream f(path);
if (!f.good()) {
return false;
}
snapshot_section * cur = nullptr;
std::string line;
while (std::getline(f, line)) {
if (line.empty() || line[0] == '#') {
continue;
}
// section header: [FTYPE_NAME] default_type
if (line[0] == '[') {
auto close = line.find(']');
if (close == std::string::npos) {
fprintf(stderr, "parse error: missing ] in '%s'\n", line.c_str());
return false;
}
std::string ftype_str = line.substr(1, close - 1);
std::string default_str;
size_t pos = close + 1;
while (pos < line.size() && line[pos] == ' ') {
pos++;
}
default_str = line.substr(pos);
llama_ftype ftype = llama_ftype_from_name(ftype_str.c_str());
if ((int) ftype < 0) {
fprintf(stderr, "parse error: unknown ftype '%s'\n", ftype_str.c_str());
return false;
}
ggml_type dtype = ggml_type_from_name(default_str);
if (dtype == GGML_TYPE_COUNT) {
fprintf(stderr, "parse error: unknown default type '%s'\n", default_str.c_str());
return false;
}
sections.push_back({ ftype, dtype, {} });
cur = &sections.back();
continue;
}
if (!cur) {
fprintf(stderr, "parse error: tensor line before any section: '%s'\n", line.c_str());
return false;
}
auto sp = line.rfind(' ');
if (sp == std::string::npos) {
fprintf(stderr, "parse error: no space in tensor line: '%s'\n", line.c_str());
return false;
}
std::string tname = line.substr(0, sp);
std::string ttype = line.substr(sp + 1);
ggml_type gt = ggml_type_from_name(ttype);
if (gt == GGML_TYPE_COUNT) {
fprintf(stderr, "parse error: unknown type '%s' for tensor '%s'\n", ttype.c_str(), tname.c_str());
return false;
}
cur->overrides.push_back({ tname, gt });
}
return true;
}
// ---------------------------------------------------------------------------
// Remote model support using gguf-model-data.cpp
// ---------------------------------------------------------------------------
struct remote_model_spec {
const char * repo;
const char * quant;
};
// Get model name from repo: strip org prefix, strip -GGUF suffix,
// and strip anything up to and including first '_' (e.g. "deepseek-ai_DeepSeek-V3.1").
static std::string model_name_from_repo(const char * repo) {
std::string s(repo);
auto slash = s.find('/');
if (slash != std::string::npos) {
s = s.substr(slash + 1);
}
const std::string suffix = "-GGUF";
if (s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0) {
s = s.substr(0, s.size() - suffix.size());
}
auto underscore = s.find('_');
if (underscore != std::string::npos) {
s = s.substr(underscore + 1);
}
return s;
}
static std::string snapshot_file_from_name(const std::string & name) {
std::string lower = name;
for (auto & c : lower) {
c = std::tolower(c);
}
return lower;
}
static const remote_model_spec model_specs[] = {
{ "ggml-org/Qwen3-0.6B-GGUF", "Q8_0" },
{ "ggml-org/GLM-4.6V-GGUF", "Q8_0" },
{ "ggml-org/Step-3.5-Flash-GGUF", "Q4_K" },
{ "ggml-org/Qwen3-Coder-Next-GGUF", "Q8_0" },
{ "ggml-org/Qwen3-14B-GGUF", "Q8_0" },
{ "ggml-org/Nemotron-Nano-3-30B-A3B-GGUF", "Q8_0" },
{ "ggml-org/gpt-oss-120b-GGUF", "mxfp4" },
{ "ggml-org/gemma-3-4b-it-GGUF", "Q8_0" },
{ "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", "Q4_K_M" },
{ "bartowski/deepseek-ai_DeepSeek-V3.1-GGUF", "IQ1_M" },
{ "bartowski/Qwen_Qwen3.5-397B-A17B-GGUF", "IQ1_S" }, // TODO: swap with ggml-org if/when it's released
{ "bartowski/Qwen_Qwen3.5-27B-GGUF", "Q8_0" }, // TODO: swap with ggml-org if/when it's released
};
static const int n_model_specs = (int) (sizeof(model_specs) / sizeof(model_specs[0]));
static llama_model * build_mock_model_from_remote(const gguf_remote_model & remote) {
llama_quant_model_desc desc = {};
desc.architecture = remote.architecture.c_str();
desc.n_embd = remote.n_embd;
desc.n_ff = remote.n_ff;
desc.n_layer = remote.n_layer;
desc.n_head = remote.n_head;
desc.n_head_kv = remote.n_head_kv;
desc.n_expert = remote.n_expert;
desc.n_embd_head_k = remote.n_embd_head_k;
desc.n_embd_head_v = remote.n_embd_head_v;
return llama_quant_model_from_metadata(&desc);
}
// Single ggml context holding all quantizable tensors for a model.
struct mock_tensors {
ggml_context_ptr ctx;
std::vector<ggml_tensor *> tensors;
};
static mock_tensors build_mock_tensors(const quantize_state_impl * qs, const gguf_remote_model & remote) {
const size_t ctx_size = remote.tensors.size() * ggml_tensor_overhead();
struct ggml_init_params params = { ctx_size, nullptr, true };
ggml_context_ptr ctx(ggml_init(params));
std::vector<ggml_tensor *> result;
for (const auto & t : remote.tensors) {
ggml_tensor * gt = ggml_new_tensor_4d(ctx.get(), GGML_TYPE_F32, t.ne[0], t.ne[1], t.ne[2], t.ne[3]);
ggml_set_name(gt, t.name.c_str());
if (llama_quant_tensor_allows_quantization(qs, gt)) {
result.push_back(gt);
}
}
// sort by layer index then name, matching llama_model_loader::weight_name_comparer
std::sort(result.begin(), result.end(), [](const ggml_tensor * a, const ggml_tensor * b) {
int a_layer = -1, b_layer = -1;
sscanf(a->name, "blk.%d.", &a_layer);
sscanf(b->name, "blk.%d.", &b_layer);
if (a_layer != b_layer) {
return a_layer < b_layer;
}
return strcmp(a->name, b->name) < 0;
});
return { std::move(ctx), std::move(result) };
}
// ---------------------------------------------------------------------------
// Generate mode: regenerate all snapshot files
// Use this when either adding new models or modifying quants
// ---------------------------------------------------------------------------
static std::string generate_snapshot(const std::string & name,
const gguf_remote_model & remote,
quantize_state_impl * qs,
mock_tensors & mt) {
std::ostringstream out;
out << "# Model: " << name << "\n";
out << "# n_embd=" << remote.n_embd << ", n_ff=" << remote.n_ff << ", n_vocab=" << remote.n_vocab
<< ", n_layer=" << remote.n_layer << ", n_head=" << remote.n_head << ", n_head_kv=" << remote.n_head_kv;
if (remote.n_expert > 0) {
out << ", n_expert=" << remote.n_expert;
}
out << "\n";
for (int i = 0; i < LLAMA_FTYPE_GUESSED; i++) {
llama_ftype ft = (llama_ftype) i;
ggml_type default_type = llama_ftype_get_default_type(ft);
if (default_type == GGML_TYPE_COUNT) {
continue;
}
const char * fname = llama_ftype_to_name(ft);
if (!fname) {
continue;
}
std::vector<ggml_type> result_types(mt.tensors.size());
llama_quant_compute_types(qs, ft, mt.tensors.data(), result_types.data(), mt.tensors.size());
out << "\n[" << fname << "] " << ggml_type_name(default_type) << "\n";
for (size_t j = 0; j < mt.tensors.size(); j++) {
if (result_types[j] != default_type) {
out << ggml_get_name(mt.tensors[j]) << " " << ggml_type_name(result_types[j]) << "\n";
}
}
}
return out.str();
}
static int run_generate(const std::string & snapshot_dir) {
fprintf(stderr, "This will overwrite all snapshot files in:\n %s\n", snapshot_dir.c_str());
fprintf(stderr, "Continue? [y/N] ");
int ch = fgetc(stdin);
if (ch != 'y' && ch != 'Y') {
fprintf(stderr, "Aborted.\n");
return 1;
}
fprintf(stderr, "\n");
int n_written = 0;
for (int m = 0; m < n_model_specs; m++) {
const auto & spec = model_specs[m];
std::string name = model_name_from_repo(spec.repo);
fprintf(stderr, "Fetching model metadata for %s from %s...\n", name.c_str(), spec.repo);
auto result = gguf_fetch_model_meta(spec.repo, spec.quant);
if (!result.has_value()) {
fprintf(stderr, "ERROR: could not fetch model metadata for %s\n", name.c_str());
return 1;
}
const auto & remote = result.value();
llama_model * model = build_mock_model_from_remote(remote);
llama_model_quantize_params qparams = llama_model_quantize_default_params();
quantize_state_impl * qs = llama_quant_init(model, &qparams);
auto mt = build_mock_tensors(qs, remote);
std::string content = generate_snapshot(name, remote, qs, mt);
std::string path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
std::ofstream f(path);
if (!f.good()) {
fprintf(stderr, "ERROR: could not write %s\n", path.c_str());
llama_quant_free(qs);
llama_model_free(model);
return 1;
}
f << content;
n_written++;
fprintf(stderr, " wrote %s\n", path.c_str());
llama_quant_free(qs);
llama_model_free(model);
}
fprintf(stderr, "%d files written\n", n_written);
return 0;
}
// ---------------------------------------------------------------------------
// Test mode: compare against snapshot files
// ---------------------------------------------------------------------------
static bool run_test_section(quantize_state_impl * qs, mock_tensors & mt, const snapshot_section & section) {
// verify default_type matches what llama_ftype_get_default_type returns
ggml_type computed_default = llama_ftype_get_default_type(section.ftype);
if (computed_default != section.default_type) {
printf(" FAIL [%s] default type mismatch: file says %s, code says %s\n", llama_ftype_to_name(section.ftype),
ggml_type_name(section.default_type), ggml_type_name(computed_default));
return false;
}
std::vector<ggml_type> result_types(mt.tensors.size());
llama_quant_compute_types(qs, section.ftype, mt.tensors.data(), result_types.data(), mt.tensors.size());
std::map<std::string, ggml_type> override_map(section.overrides.begin(), section.overrides.end());
bool all_pass = true;
int n_override_found = 0;
for (size_t i = 0; i < mt.tensors.size(); i++) {
const char * name = ggml_get_name(mt.tensors[i]);
ggml_type got = result_types[i];
ggml_type expected = section.default_type;
auto it = override_map.find(name);
if (it != override_map.end()) {
expected = it->second;
n_override_found++;
}
if (got != expected) {
printf(" FAIL %-50s %-10s expected %s, got %s\n", name, llama_ftype_to_name(section.ftype),
ggml_type_name(expected), ggml_type_name(got));
all_pass = false;
}
}
if (n_override_found != (int) section.overrides.size()) {
printf(" FAIL [%s] override count mismatch: listed %d, matched %d\n", llama_ftype_to_name(section.ftype),
(int) section.overrides.size(), n_override_found);
all_pass = false;
}
return all_pass;
}
static int run_remote_tests(const std::string & snapshot_dir, const char * argv0) {
int total_pass = 0;
int total_fail = 0;
int total_skip = 0;
for (int m = 0; m < n_model_specs; m++) {
const auto & spec = model_specs[m];
std::string name = model_name_from_repo(spec.repo);
printf("=== %s ===\n", name.c_str());
auto result = gguf_fetch_model_meta(spec.repo, spec.quant, "", false);
if (!result.has_value()) {
printf(" SKIP (could not fetch model metadata)\n\n");
total_skip++;
continue;
}
const auto & remote = result.value();
llama_model * model = build_mock_model_from_remote(remote);
llama_model_quantize_params qparams = llama_model_quantize_default_params();
quantize_state_impl * qs = llama_quant_init(model, &qparams);
auto mt = build_mock_tensors(qs, remote);
std::string snapshot_path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
std::vector<snapshot_section> sections;
if (!parse_snapshot_file(snapshot_path, sections)) {
printf(" SKIP (could not read snapshot file: %s)\n\n", snapshot_path.c_str());
llama_quant_free(qs);
llama_model_free(model);
total_skip++;
continue;
}
int model_pass = 0;
int model_fail = 0;
for (const auto & section : sections) {
bool pass = run_test_section(qs, mt, section);
if (pass) {
model_pass++;
} else {
model_fail++;
}
}
printf(" %s %s: %d/%d ftype sections passed (%d tensors)\n", model_fail == 0 ? "PASS" : "FAIL", name.c_str(),
model_pass, model_pass + model_fail, (int) mt.tensors.size());
printf("\n");
if (model_fail == 0) {
total_pass++;
} else {
total_fail++;
}
llama_quant_free(qs);
llama_model_free(model);
}
printf("%d/%d models passed", total_pass, total_pass + total_fail);
if (total_skip > 0) {
printf(", %d skipped", total_skip);
}
printf("\n");
if (total_fail > 0) {
printf("\nIf these changes are intentional, regenerate snapshot files with:\n");
printf(" %s --generate\n", argv0);
}
return total_fail > 0 ? 1 : 0;
}
int main(int argc, char ** argv) {
std::string snapshot_dir = SNAPSHOT_DIR;
bool generate = false;
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "--generate") == 0) {
generate = true;
} else if (strcmp(argv[i], "--snapshot-dir") == 0 && i + 1 < argc) {
snapshot_dir = argv[++i];
}
}
if (generate) {
return run_generate(snapshot_dir);
}
// suppress llama log warnings during test (e.g. tensor type fallback messages)
llama_log_set([](enum ggml_log_level, const char *, void *) {}, nullptr);
return run_remote_tests(snapshot_dir, argv[0]);
}