Compare commits

...

38 Commits

Author SHA1 Message Date
Olivier Chafik d785f9c1fd server: fix/test add_generation_prompt (#13770)
Co-authored-by: ochafik <ochafik@google.com>
2025-05-25 10:45:49 +01:00
Piotr Jasiukajtis 4032ca4066 llama : add support for Qwen3 MoE tied word embeddings (#13768) 2025-05-25 10:29:43 +02:00
Akarshan Biswas 515fdbf7ed SYCL: revert "sycl: simplify bin_bcast_kernel (#13383)" (#13752)
Temporarily reverted due to failing fp16 DIV operation

This reverts commit 02cdd2d8b0.

ggml-ci
2025-05-25 10:08:37 +03:00
Olivier Chafik f5cd27b71d server: streaming of tool calls and thoughts when --jinja is on (#12379)
* add common_json w/ support for truncated json healing

* add common_chat_msg_diff

* partial common_chat_parse

* refactor parser w/ optionals

* server: wire chat diffs in stream mode

* fix trigger of thinking models (must happen after thoughts are closed)

* fix functionary v3.2 raw python!

* rename: common_chat_syntax (now contains format)

* rm common_regex.at_start

* don't return empty <think></think>

* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)

* fix QwQ 32B tool call parsing after thoughts (hermes2)

* better logs for grammar triggers

* consume spaces after parse_json_tool_calls

* fix required tool calls w/ thinking models that have pre-opened thinking tags

* fix thinking model's initial trigger + test qwq's template

* run most test_tool_call tests in stream + non-stream modes

* make functionary v3.2 parsing more strict (differentiate first match from others)

* send final diff from server, to close off raw python arguments

* support partial content streaming in Generic mode

* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)

* Update function-calling.md

* Update tool_bench.py

* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
2025-05-25 01:48:08 +01:00
Diego Devesa a2d02d5793 releases : bundle llvm omp library in windows release (#13763) 2025-05-25 00:55:16 +02:00
Diego Devesa 17fc817b58 releases : enable openmp in windows cpu backend build (#13756) 2025-05-24 22:27:03 +02:00
Diego Devesa 2bd1b30f69 ggml-cpu : set openmp wait time if not set (#13758) 2025-05-24 22:26:47 +02:00
0cc4m 259469c4b5 Move GLM4 f32 attention fix to the correct function (#13750) 2025-05-24 16:49:12 +02:00
Xuan-Son Nguyen 4c32832c59 ggml : add ggml_gelu_erf() CUDA kernel (#13719)
* ggml : add ggml_gelu_erf() CUDA kernel

* missing semicolon
2025-05-24 13:06:47 +02:00
Sigbjørn Skjæret c3a2624339 vocab : fix ugm tokenizer precision (#13743) 2025-05-24 12:29:09 +02:00
Johannes Gäßler ffd0eae60b CUDA: fix race condition in FA vector kernels (#13742) 2025-05-24 11:46:19 +02:00
Diego Devesa b775345d78 ci : enable winget package updates (#13734) 2025-05-23 23:14:00 +03:00
Diego Devesa a70a8a69c2 ci : add winget package updater (#13732) 2025-05-23 22:09:38 +02:00
Georgi Gerganov d13d0f6135 hparams : initialize arrays (#13728)
ggml-ci
2025-05-23 20:16:13 +03:00
Xuan-Son Nguyen 8a2afb7520 llama : allow custom list of swa_layers (#13726) 2025-05-23 17:07:04 +02:00
Xuan-Son Nguyen 9ecf3e66a3 server : support audio input (#13714)
* server : support audio input

* add audio support on webui
2025-05-23 11:03:47 +02:00
Chenguang Li faaaff5f94 CANN: Support MUL_MAT_ID for q8_0 and q4_0 (#13705)
* [CANN]Support MUL_MAT_ID Q8 && Q4

Signed-off-by: noemotiovon <757486878@qq.com>

* codestyle adjustment

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
2025-05-23 16:47:53 +08:00
Xuan-Son Nguyen e16c4731c7 ggml : fix the order of ggml_unary_op (#13718) 2025-05-23 08:12:48 +02:00
Jeff Bolz 1dcd01960c vulkan: support CPY from any type to itself (#13695)
Reuse the f16/f32 copy shaders, and just scale the number of elements
according to the type size.
2025-05-23 06:45:02 +02:00
Jeff Bolz c10ed6cbcc vulkan: Disable coopmat/coopmat2/bfloat extensions if glslc doesn't support it (#13696) 2025-05-23 06:33:45 +02:00
Judd a127ff1780 use LOG_WARN to replace std::cerr (#13657) 2025-05-23 06:33:08 +02:00
Diego Devesa 3079e9ac8e release : fix windows hip release (#13707)
* release : fix windows hip release

* make single hip release with multiple targets
2025-05-23 00:21:37 +02:00
Georgi Gerganov 8a1d206f1d tts : fix n_ubatch + make WavTokenizer cache-less (#13713)
ggml-ci
2025-05-22 22:21:07 +03:00
Xuan-Son Nguyen 797990c4bc mtmd : add ultravox audio input (#13623)
* convert ok, load ok

* warmup ok

* test

* still does not work?

* fix padding

* temporary give up

* fix merge conflict

* build_ultravox()

* rm test

* fix merge conflict

* add necessary mtmd APIs

* first working version (only 4s of audio)

* will this monster compile?

* fix compile

* please compile

* fPIC

* fix windows

* various fixes

* clean up audio_helpers

* fix conversion

* add some debug stuff

* long audio input ok

* adapt the api

* add --audio arg

* final touch UX

* add miniaudio to readme

* fix typo

* refactor kv metadata

* mtmd_default_marker()
2025-05-22 20:42:48 +02:00
Aaron Teo ab86335760 common: Include torch package for s390x (#13699)
* common: update requirements.txt to include pytorch nightly for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* common: fix torch installation via pip for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-05-22 21:31:29 +03:00
Georgi Gerganov cc74d5be99 server : pad small embedding batches (#13692)
ggml-ci
2025-05-22 16:33:39 +03:00
Sigbjørn Skjæret 5be24af73d gguf-py : correct charsmap parameter typing (#13701) 2025-05-22 14:25:05 +02:00
Nicolò Scipione d394a9aedc sycl : Remove waits from function calls (#13702)
* removes the waits in async memcpy functions
2025-05-22 12:54:43 +01:00
Ewan Crawford 6b56a64690 SYCL: Avoid using with SYCL-Graph for unsupported nodes (#13587)
Currently on a CUDA backend to SYCL when running
`GGML_SYCL_DISABLE_GRAPH=0 ./bin/test-backend-ops -b SYCL0` there
are two operations that throw an exception from the blocking
waits during queue recording.

* `-o CONCAT` : Use of blocking waits on a queue that's being recorded https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-sycl/concat.cpp#L185-L187
* `-o MUL_MAT_ID`: Blocking wait on a recording queue for a copy to host memory https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-sycl/ggml-sycl.cpp#L3072-L3074

We've noticed that `ggml-cuda.cu` has the
[check_node_graph_compatibility_and_refresh_copy_ops](https://github.com/ggml-org/llama.cpp/blob/39e73ae0d69f882d7e29cecc6dd8f5052fca6731/ggml/src/ggml-cuda/ggml-cuda.cu#L2458-L2458)
method for checking if a graph can be used, even if enabled. I've taken a
similar approach in this PR by adding a method to `ggml-sycl.cpp` for checking
if a graph can be used for the operations even if a user has asked for it to be
enabled.
2025-05-22 16:24:09 +08:00
Henry Linjamäki a4e8912dfd opencl: Add support for multiple devices (#12622)
* opencl: Add support for multiple devices

... but limited to one platform. A platform with a GPU will be preferred.

Additionally:

* Filter out devices that lack capabilities needed by the backend
  implementation (half support, OpenCL 2.0+, etc).

* Make ggml_backend_opencl_reg() thread-safe.

* fixup: fix an error in sync_with_other_backends

... when there is only one OpenCL device available.
2025-05-21 16:21:45 -07:00
Henry Linjamäki edbf42edfd opencl: fix couple crashes (#12795)
* opencl: fix couple crashes

* fix kernel launches failed on devices which do not support
  non-uniform work-groups. When non-uniform work-groups are not
  supported, set `local_work_size` to NULL (= let driver choose the
  work-group sizes). This patch does not cover everything - just the
  cases tested by test-backend-ops.

* fix sub-buffer creation failed due to `cl_buffer_region::origin` not
  being aligned to `CL_DEVICE_MEM_BASE_ADDR_ALIGN`.

* OpenCL: query non-uniform WG sizes only on OpenCL 3.0+
2025-05-21 13:21:17 -07:00
Diego Devesa d643bb2c79 releases : build CPU backend separately (windows) (#13642) 2025-05-21 22:09:57 +02:00
Georgi Gerganov 8e186ef0e7 hparams : support models for which all layers use SWA (#13682)
ggml-ci
2025-05-21 20:00:49 +03:00
Georgi Gerganov 5fbfe384d4 server : improve error reporting (#13680) 2025-05-21 19:46:56 +03:00
antichristHater c76532e7ba convert : add qwen2vl support for unsloth merges (#13686) 2025-05-21 18:40:35 +02:00
Sigbjørn Skjæret 2aa777d86d examples : switch retrieval to llama_encode (#13685)
* switch retrieval to llama_encode

* enable --no-warmup for retrieval
2025-05-21 16:57:38 +02:00
Emmanuel Ferdman eb0f5c28d3 gguf-py : display the invalid gguf type (#13687)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-21 16:33:54 +02:00
Xuan-Son Nguyen cf4cb59e64 ggml : add ggml_gelu_erf() (#13667)
* ggml : add ggml_gelu_na (not approximated)

* fix naming order

* rename na --> erf

* apply review suggesions

* revert naming order
2025-05-21 16:26:33 +02:00
81 changed files with 100295 additions and 1972 deletions
+4
View File
@@ -48,3 +48,7 @@ end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset
[tools/mtmd/miniaudio.h]
trim_trailing_whitespace = unset
insert_final_newline = unset
+154 -128
View File
@@ -1,4 +1,4 @@
name: Create Release
name: Release
on:
workflow_dispatch: # allows manual triggering
@@ -227,6 +227,69 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
name: llama-bin-ubuntu-vulkan-x64.zip
windows-cpu:
runs-on: windows-latest
strategy:
matrix:
include:
- arch: 'x64'
- arch: 'arm64'
steps:
- name: Clone
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-cpu-${{ matrix.arch }}
variant: ccache
evict-old-files: 1d
- name: Install Ninja
run: |
choco install ninja
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
with:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
- name: Build
shell: cmd
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }}
cmake -S . -B build -G "Ninja Multi-Config" ^
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
-DGGML_NATIVE=OFF ^
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^
-DGGML_OPENMP=ON ^
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^
${{ env.CMAKE_ARGS }}
cmake --build build --config Release
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-bin-win-cpu-${{ matrix.arch }}.zip
name: llama-bin-win-cpu-${{ matrix.arch }}.zip
windows:
runs-on: windows-latest
@@ -237,52 +300,30 @@ jobs:
strategy:
matrix:
include:
- build: 'cpu-x64'
- backend: 'vulkan'
arch: 'x64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF'
#- build: 'openblas-x64'
# arch: 'x64'
# defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
- build: 'vulkan-x64'
arch: 'x64'
defines: '-DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
- build: 'cpu-arm64'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF'
- build: 'opencl-adreno-arm64'
defines: '-DGGML_VULKAN=ON'
target: 'ggml-vulkan'
- backend: 'opencl-adreno'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
target: 'ggml-opencl'
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-${{ matrix.build }}
key: windows-latest-cmake-${{ matrix.backend }}-${{ matrix.arch }}
variant: ccache
evict-old-files: 1d
- name: Download OpenBLAS
id: get_openblas
if: ${{ matrix.build == 'openblas-x64' }}
run: |
curl.exe -o $env:RUNNER_TEMP/openblas.zip -L "https://github.com/xianyi/OpenBLAS/releases/download/v${env:OPENBLAS_VERSION}/OpenBLAS-${env:OPENBLAS_VERSION}-x64.zip"
curl.exe -o $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt -L "https://github.com/xianyi/OpenBLAS/raw/v${env:OPENBLAS_VERSION}/LICENSE"
mkdir $env:RUNNER_TEMP/openblas
tar.exe -xvf $env:RUNNER_TEMP/openblas.zip -C $env:RUNNER_TEMP/openblas
$vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath)
$msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim()))
$lib = $(join-path $msvc 'bin\Hostx64\x64\lib.exe')
& $lib /machine:x64 "/def:${env:RUNNER_TEMP}/openblas/lib/libopenblas.def" "/out:${env:RUNNER_TEMP}/openblas/lib/openblas.lib" /name:openblas.dll
- name: Install Vulkan SDK
id: get_vulkan
if: ${{ matrix.build == 'vulkan-x64' }}
if: ${{ matrix.backend == 'vulkan' }}
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
@@ -296,7 +337,7 @@ jobs:
- name: Install OpenCL Headers and Libs
id: install_opencl
if: ${{ matrix.build == 'opencl-adreno-arm64' }}
if: ${{ matrix.backend == 'opencl-adreno' && matrix.arch == 'arm64' }}
run: |
git clone https://github.com/KhronosGroup/OpenCL-Headers
cd OpenCL-Headers
@@ -314,46 +355,22 @@ jobs:
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
cmake --build build-arm64-release --target install --config release
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
with:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cmake -S . -B build ${{ matrix.defines }} `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" `
${{ env.CMAKE_ARGS }}
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
- name: Add libopenblas.dll
id: add_libopenblas_dll
if: ${{ matrix.build == 'openblas-x64' }}
run: |
cp $env:RUNNER_TEMP/openblas/bin/libopenblas.dll ./build/bin/Release/openblas.dll
cp $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt ./build/bin/Release/OpenBLAS-${env:OPENBLAS_VERSION}.txt
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
cmake -S . -B build ${{ matrix.defines }} -DGGML_NATIVE=OFF -DGGML_CPU=OFF -DGGML_BACKEND_DL=ON -DLLAMA_CURL=OFF
cmake --build build --config Release --target ${{ matrix.target }}
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\*
7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip
name: llama-bin-win-${{ matrix.build }}.zip
path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
windows-cuda:
runs-on: windows-2019
@@ -366,8 +383,6 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install ccache
uses: hendrikmuhs/ccache-action@v1.2.16
@@ -386,45 +401,30 @@ jobs:
run: |
choco install ninja
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
shell: cmd
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake -S . -B build -G "Ninja Multi-Config" ^
-DGGML_NATIVE=OFF ^
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=ON ^
-DGGML_NATIVE=OFF ^
-DGGML_CPU=OFF ^
-DGGML_CUDA=ON ^
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^
${{ env.CMAKE_ARGS }}
-DLLAMA_CURL=OFF
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
cmake --build build --config Release
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll
7z a llama-${{ steps.tag.outputs.name }}-bin-win-cuda${{ matrix.cuda }}-x64.zip .\build\bin\Release\*
7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-cuda${{ matrix.cuda }}-x64.zip
name: llama-bin-win-cuda${{ matrix.cuda }}-x64.zip
path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
- name: Copy and pack Cuda runtime
run: |
@@ -432,13 +432,13 @@ jobs:
$dst='.\build\bin\cudart\'
robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
7z a cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip $dst\*
7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*
- name: Upload Cuda runtime
uses: actions/upload-artifact@v4
with:
path: cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip
name: cudart-llama-bin-win-cuda${{ matrix.cuda }}-x64.zip
path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
windows-sycl:
runs-on: windows-latest
@@ -451,12 +451,11 @@ jobs:
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
@@ -469,15 +468,18 @@ jobs:
run: |
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
# TODO: add libcurl support ; we will also need to modify win-build-sycl.bat to accept user-specified args
- name: Build
id: cmake_build
run: examples/sycl/win-build-sycl.bat
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
shell: cmd
run: |
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
cmake -G "Ninja" -B build ^
-DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx ^
-DCMAKE_BUILD_TYPE=Release ^
-DGGML_BACKEND_DL=ON -DBUILD_SHARED_LIBS=ON ^
-DGGML_CPU=OFF -DGGML_SYCL=ON ^
-DLLAMA_CURL=OFF
cmake --build build --target ggml-sycl -j
- name: Build the release package
id: pack_artifacts
@@ -502,12 +504,12 @@ jobs:
cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin
echo "cp oneAPI running time dll files to ./build/bin done"
7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/*
7z a llama-bin-win-sycl-x64.zip ./build/bin/*
- name: Upload the release package
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip
path: llama-bin-win-sycl-x64.zip
name: llama-bin-win-sycl-x64.zip
windows-hip:
@@ -515,14 +517,14 @@ jobs:
strategy:
matrix:
gpu_target: [gfx1100, gfx1101, gfx1030]
include:
- name: "radeon"
gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Clone rocWMMA repository
id: clone_rocwmma
@@ -532,7 +534,7 @@ jobs:
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-hip-release
key: windows-latest-cmake-hip-${{ matrix.name }}-x64
evict-old-files: 1d
- name: Install
@@ -550,50 +552,39 @@ jobs:
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
- name: libCURL
id: get_libcurl
uses: ./.github/actions/windows-setup-curl
- name: Build
id: cmake_build
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_BUILD_TYPE=Release `
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
-DGGML_BACKEND_DL=ON `
-DGGML_NATIVE=OFF `
-DGGML_CPU=OFF `
-DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGGML_HIP=ON `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" `
${{ env.CMAKE_ARGS }}
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
-DLLAMA_CURL=OFF
cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
md "build\bin\rocblas\library\"
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
id: pack_artifacts
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\libcurl-x64.dll
7z a llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip .\build\bin\*
7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip
name: llama-bin-win-hip-x64-${{ matrix.gpu_target }}.zip
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
ios-xcode-build:
runs-on: macos-latest
@@ -655,14 +646,16 @@ jobs:
runs-on: ubuntu-latest
needs:
- ubuntu-22-cpu
- ubuntu-22-vulkan
- windows
- windows-cpu
- windows-cuda
- windows-sycl
- windows-hip
- ubuntu-22-cpu
- ubuntu-22-vulkan
- macOS-arm64
- macOS-x64
- ios-xcode-build
steps:
- name: Clone
@@ -680,10 +673,43 @@ jobs:
uses: actions/download-artifact@v4
with:
path: ./artifact
merge-multiple: true
- name: Move artifacts
id: move_artifacts
run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release
run: |
mkdir -p release
echo "Adding CPU backend files to existing zips..."
for arch in x64 arm64; do
cpu_zip="artifact/llama-bin-win-cpu-${arch}.zip"
temp_dir=$(mktemp -d)
echo "Extracting CPU backend for $arch..."
unzip "$cpu_zip" -d "$temp_dir"
echo "Adding CPU files to $arch zips..."
for target_zip in artifact/llama-bin-win-*-${arch}.zip; do
if [[ "$target_zip" == "$cpu_zip" ]]; then
continue
fi
echo "Adding CPU backend to $(basename "$target_zip")"
realpath_target_zip=$(realpath "$target_zip")
(cd "$temp_dir" && zip -r "$realpath_target_zip" .)
done
rm -rf "$temp_dir"
done
echo "Renaming and moving zips to release..."
for zip_file in artifact/llama-bin-win-*.zip; do
base_name=$(basename "$zip_file" .zip)
zip_name="llama-${{ steps.tag.outputs.name }}-${base_name#llama-}.zip"
echo "Moving $zip_file to release/$zip_name"
mv "$zip_file" "release/$zip_name"
done
echo "Moving other artifacts..."
mv -v artifact/*.zip release
- name: Create release
id: create_release
@@ -702,7 +728,7 @@ jobs:
const path = require('path');
const fs = require('fs');
const release_id = '${{ steps.create_release.outputs.id }}';
for (let file of await fs.readdirSync('./artifact/release')) {
for (let file of await fs.readdirSync('./release')) {
if (path.extname(file) === '.zip') {
console.log('uploadReleaseAsset', file);
await github.repos.uploadReleaseAsset({
@@ -710,7 +736,7 @@ jobs:
repo: context.repo.repo,
release_id: release_id,
name: file,
data: await fs.readFileSync(`./artifact/release/${file}`)
data: await fs.readFileSync(`./release/${file}`)
});
}
}
+42
View File
@@ -0,0 +1,42 @@
name: Update Winget Package
on:
workflow_dispatch: # allows manual triggering
schedule:
- cron: '28 5 * * *' # Update every day at 5:28 UTC
jobs:
update:
name: Update Winget Package
runs-on: ubuntu-latest
steps:
- name: Install cargo binstall
uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b
- name: Install komac
run: |
cargo binstall komac@2.11.2 -y
- name: Find latest release
id: find_latest_release
uses: actions/github-script@v6
with:
script: |
const { data: releases } = await github.rest.repos.listReleases({
owner: context.repo.owner,
repo: context.repo.repo,
});
console.log("Latest release:", releases[0].tag_name);
return releases[0].tag_name;
- name: Update manifest
env:
VERSION: ${{ steps.find_latest_release.outputs.result }}
run: |
echo "Updating manifest..."
komac update --version ${{ env.VERSION }} \
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
--submit \
ggml.llamacpp
+1
View File
@@ -580,3 +580,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain
+4
View File
@@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
base64.hpp
chat.cpp
chat.h
chat-parser.cpp
chat-parser.h
common.cpp
common.h
console.cpp
console.h
json-schema-to-grammar.cpp
json.hpp
json-partial.h
json-partial.cpp
llguidance.cpp
log.cpp
log.h
+6 -6
View File
@@ -39,7 +39,7 @@
using json = nlohmann::ordered_json;
std::initializer_list<enum llama_example> mmproj_examples = {
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_SERVER,
};
@@ -1678,7 +1678,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.warmup = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
{"--spm-infill"},
string_format(
@@ -2233,12 +2233,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image"}, "FILE",
"path to an image file. use with multimodal models. Specify multiple times for batching",
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
[](common_params & params, const std::string & value) {
params.image.emplace_back(value);
}
).set_examples({LLAMA_EXAMPLE_LLAVA}));
).set_examples({LLAMA_EXAMPLE_MTMD}));
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
@@ -2868,7 +2868,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(
+376
View File
@@ -0,0 +1,376 @@
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
result_.role = "assistant";
while (true) {
std::string id = std::to_string(std::rand());
if (input.find(id) == std::string::npos) {
healing_marker_ = id;
break;
}
}
}
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
GGML_ASSERT(rng.begin <= rng.end);
return input_.substr(rng.begin, rng.end - rng.begin);
}
void common_chat_msg_parser::add_content(const std::string &content) {
result_.content += content;
}
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
result_.reasoning_content += reasoning_content;
}
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
if (name.empty()) {
return false;
}
common_chat_tool_call tool_call;
tool_call.name = name;
tool_call.arguments = arguments;
tool_call.id = id;
// LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
result_.tool_calls.emplace_back(tool_call);
return true;
}
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
return add_tool_call(name, id, arguments);
}
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
for (const auto & item : arr) {
if (!add_tool_call(item)) {
return false;
}
}
return true;
}
void common_chat_msg_parser::finish() {
if (!is_partial_ && pos_ != input_.size()) {
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
}
}
bool common_chat_msg_parser::consume_spaces() {
const auto length = input_.size();
auto consumed = false;
while (pos_ < length && std::isspace(input_[pos_])) {
++pos_;
consumed = true;
}
return consumed;
}
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
auto pos = pos_;
for (auto i = 0u; i < literal.size(); ++i) {
if (pos >= input_.size()) {
return false;
}
if (input_[pos] != literal[i]) {
return false;
}
++pos;
}
pos_ = pos;
return true;
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
auto idx = input_.find(literal, pos_);
if (idx != std::string::npos) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = idx + literal.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
if (is_partial_) {
idx = string_find_partial_stop(input_, literal);
if (idx != std::string::npos && idx >= pos_) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = input_.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
}
return std::nullopt;
}
void common_chat_msg_parser::consume_literal(const std::string & literal) {
if (!try_consume_literal(literal)) {
throw common_chat_msg_partial_exception(literal);
}
}
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
auto stripped_reasoning = string_strip(reasoning);
if (stripped_reasoning.empty()) {
return;
}
if (syntax_.reasoning_in_content) {
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
add_content(stripped_reasoning);
if (closed) {
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
}
} else {
add_reasoning_content(stripped_reasoning);
}
};
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
if (auto res = try_find_literal(end_think)) {
handle_reasoning(res->prelude, /* closed */ true);
consume_spaces();
return true;
}
auto rest = consume_rest();
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
if (!syntax_.thinking_forced_open) {
throw common_chat_msg_partial_exception(end_think);
}
return true;
}
}
return false;
}
std::string common_chat_msg_parser::consume_rest() {
auto rest = input_.substr(pos_);
pos_ = input_.size();
return rest;
}
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
return find_regex_result{prelude, m.groups};
}
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
if (auto result = try_consume_regex(regex)) {
return *result;
}
throw common_chat_msg_partial_exception(regex.str());
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
auto m = regex.search(input_, pos_);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
if (m.groups[0].begin != pos_) {
// Didn't match at the current position.
return std::nullopt;
}
pos_ = m.groups[0].end;
return find_regex_result {
/* .prelude = */ "",
m.groups,
};
}
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
auto it = input_.cbegin() + pos_;
const auto end = input_.cend();
common_json result;
if (!common_json_parse(it, end, healing_marker_, result)) {
return std::nullopt;
}
pos_ = std::distance(input_.cbegin(), it);
if (result.healing_marker.marker.empty()) {
// No healing marker, just return the parsed json
return result;
}
if (!is_partial()) {
throw common_chat_msg_partial_exception("JSON");
}
return result;
}
common_json common_chat_msg_parser::consume_json() {
if (auto result = try_consume_json()) {
return *result;
}
throw common_chat_msg_partial_exception("JSON");
}
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
return *result;
}
throw common_chat_msg_partial_exception("JSON");
}
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
auto partial = try_consume_json();
if (!partial) {
return std::nullopt;
}
auto is_arguments_path = [&](const std::vector<std::string> & path) {
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
};
auto is_content_path = [&](const std::vector<std::string> & path) {
return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
};
if (partial->healing_marker.marker.empty()) {
if (args_paths.empty()) {
// No arguments to dump, and JSON was parsed fully.
return consume_json_result {
partial->json,
/* .is_partial = */ false,
};
}
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
/* .is_partial = */ false,
};
}
}
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
auto found_healing_marker = false;
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {
arguments.resize(idx);
found_healing_marker = true;
}
if (arguments == "\"") {
// This happens because of completing `:"$magic` after `"arguments"`
arguments = "";
}
}
return arguments;
}
if (is_content_path(path)) {
if (!j.is_string()) {
throw std::runtime_error("Content path must be a string");
}
std::string str = j;
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
if (idx != std::string::npos) {
str.resize(idx);
found_healing_marker = true;
}
return str;
}
if (j.is_object()) {
auto obj = json::object();
for (const auto & p : j.items()) {
const auto & key = p.key();
const auto & value = p.value();
const std::string key_str = key; // NOLINT
auto idx = key_str.find(healing_marker_);
if (idx != std::string::npos) {
found_healing_marker = true;
break;
}
path.push_back(key_str);
if (value.is_string()) {
const std::string value_str = value;
if (value_str.find(healing_marker_) != std::string::npos) {
found_healing_marker = true;
if (is_content_path(path)) {
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
}
break;
}
obj[key] = value;
} else {
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
path.pop_back();
}
return obj;
}
if (j.is_array()) {
auto arr = json::array();
for (const auto & value : j) {
if (value.is_string()) {
std::string str = value;
auto idx = str.find(healing_marker_);
if (idx != std::string::npos) {
// Don't heal array values that aren't in the arguments.
found_healing_marker = true;
break;
}
}
arr.push_back(remove_unsupported_healings_and_dump_args(value));
}
return arr;
}
return j;
};
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
return consume_json_result {
cleaned,
/* .is_partial = */ found_healing_marker,
};
}
+116
View File
@@ -0,0 +1,116 @@
#pragma once
#include "chat.h"
#include "json-partial.h"
#include "json.hpp"
#include "regex-partial.h"
#include <optional>
#include <string>
#include <vector>
class common_chat_msg_partial_exception : public std::runtime_error {
public:
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};
class common_chat_msg_parser {
std::string input_;
bool is_partial_;
common_chat_syntax syntax_;
std::string healing_marker_;
size_t pos_ = 0;
common_chat_msg result_;
public:
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
const std::string & input() const { return input_; }
size_t pos() const { return pos_; }
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
void move_to(size_t pos) {
if (pos > input_.size()) {
throw std::runtime_error("Invalid position!");
}
pos_ = pos;
}
void move_back(size_t n) {
if (pos_ < n) {
throw std::runtime_error("Can't move back that far!");
}
pos_ -= n;
}
// Get the substring of the input at the given range
std::string str(const common_string_range & rng) const;
// Appends to the result.content field
void add_content(const std::string & content);
// Appends to the result.reasoning_content field
void add_reasoning_content(const std::string & reasoning_content);
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
bool add_tool_call(const nlohmann::ordered_json & tool_call);
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
void finish();
bool consume_spaces();
void consume_literal(const std::string & literal);
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
std::string consume_rest();
struct find_regex_result {
std::string prelude;
std::vector<common_string_range> groups;
};
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
bool try_consume_literal(const std::string & literal);
std::optional<find_regex_result> try_find_literal(const std::string & literal);
find_regex_result consume_regex(const common_regex & regex);
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
std::optional<common_json> try_consume_json();
common_json consume_json();
struct consume_json_result {
nlohmann::ordered_json value;
bool is_partial;
};
/*
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
*/
consume_json_result consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
std::optional<consume_json_result> try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
};
+602 -524
View File
File diff suppressed because it is too large Load Diff
+67 -5
View File
@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include <functional>
#include <chrono>
#include <string>
#include <vector>
@@ -13,11 +14,19 @@ struct common_chat_tool_call {
std::string name;
std::string arguments;
std::string id;
bool operator==(const common_chat_tool_call & other) const {
return name == other.name && arguments == other.arguments && id == other.id;
}
};
struct common_chat_msg_content_part {
std::string type;
std::string text;
bool operator==(const common_chat_msg_content_part & other) const {
return type == other.type && text == other.text;
}
};
struct common_chat_msg {
@@ -28,6 +37,51 @@ struct common_chat_msg {
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
template <class T> T to_json_oaicompat() const;
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
}
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
for (auto i = 0u; i < tool_calls.size(); i++) {
if (ids_cache.size() <= i) {
auto id = tool_calls[i].id;
if (id.empty()) {
id = gen_tool_call_id();
}
ids_cache.push_back(id);
}
tool_calls[i].id = ids_cache[i];
}
}
bool operator==(const common_chat_msg & other) const {
return role == other.role
&& content == other.content
&& content_parts == other.content_parts
&& tool_calls == other.tool_calls
&& reasoning_content == other.reasoning_content
&& tool_name == other.tool_name
&& tool_call_id == other.tool_call_id;
}
bool operator!=(const common_chat_msg & other) const {
return !(*this == other);
}
};
struct common_chat_msg_diff {
// std::string reasoning_content_delta;
std::string content_delta;
size_t tool_call_index = std::string::npos;
common_chat_tool_call tool_call_delta;
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
bool operator==(const common_chat_msg_diff & other) const {
return content_delta == other.content_delta
&& tool_call_index == other.tool_call_index
&& tool_call_delta == other.tool_call_delta;
}
};
struct common_chat_tool {
@@ -49,14 +103,11 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -71,7 +122,7 @@ struct common_chat_templates_inputs {
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
@@ -80,11 +131,20 @@ struct common_chat_params {
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
bool thinking_forced_open = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_syntax {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
@@ -122,7 +182,7 @@ std::string common_chat_format_example(
bool use_jinja);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
@@ -135,3 +195,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
+2 -2
View File
@@ -76,7 +76,7 @@ enum llama_example {
LLAMA_EXAMPLE_SERVER,
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
LLAMA_EXAMPLE_EXPORT_LORA,
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
@@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
};
struct common_grammar_trigger {
+255
View File
@@ -0,0 +1,255 @@
#include <json-partial.h>
#include "ggml.h"
#include "log.h"
#include <string>
#include <json.hpp>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
COMMON_JSON_STACK_ELEMENT_OBJECT,
COMMON_JSON_STACK_ELEMENT_KEY,
COMMON_JSON_STACK_ELEMENT_ARRAY,
};
struct common_json_stack_element {
common_json_stack_element_type type;
std::string key;
};
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out)
{
std::string::const_iterator it = input.begin();
const auto end = input.end();
return common_json_parse(it, end, healing_marker, out);
}
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out)
{
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
std::string last_token;
std::string exception_message;
std::vector<common_json_stack_element> stack;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
this->position = position - 1;
this->found_error = true;
this->last_token = last_token;
this->exception_message = ex.what();
return false;
}
void close_value() {
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
stack.pop_back();
}
}
bool null() override { // NOLINT
close_value();
return true;
}
bool boolean(bool) override { // NOLINT
close_value();
return true;
}
bool number_integer(number_integer_t) override { // NOLINT
close_value();
return true;
}
bool number_unsigned(number_unsigned_t) override { // NOLINT
close_value();
return true;
}
bool number_float(number_float_t, const string_t &) override { // NOLINT
close_value();
return true;
}
bool string(string_t &) override { // NOLINT
close_value();
return true;
}
bool binary(binary_t &) override { // NOLINT
close_value();
return true;
}
bool start_object(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
return true;
}
bool end_object() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
stack.pop_back();
close_value();
return true;
}
bool key(string_t & key) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
return true;
}
bool start_array(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
return true;
}
bool end_array() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
stack.pop_back();
close_value();
return true;
}
};
json_error_locator err_loc;
auto start = it;
json::sax_parse(it, end, &err_loc);
if (err_loc.found_error) {
it = start;
auto temptative_end = it + err_loc.position;
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
auto input = std::string(it, temptative_end);
try {
out.json = json::parse(input);
// out.json = json::parse(it, temptative_end);
it = temptative_end;
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {
auto _ = json::parse(str); // NOLINT
return true;
} catch (const std::exception &) {
return false;
}
};
if (!healing_marker.empty() && !err_loc.stack.empty()) {
std::string str(it, temptative_end);
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
if (last_non_sp_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
auto last_non_sp_char = str[last_non_sp_pos];
// Used to detect stops on a number, which may not be complete.
auto was_maybe_number = [&]() {
if (!str.empty() && std::isspace(str.back())) {
return false;
}
return std::isdigit(last_non_sp_char) ||
last_non_sp_char == '.' ||
last_non_sp_char == 'e' ||
last_non_sp_char == 'E' ||
last_non_sp_char == '-';
};
std::string closing;
for (size_t i = err_loc.stack.size(); i > 0; i--) {
auto & el = err_loc.stack[i - 1];
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
closing += "}";
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
closing += "]";
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
throw std::runtime_error("Unexpected stack element type");
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
// We're inside an object value
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
// Was about to create an object value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + ": 1" + closing)) {
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
// Was about to create an object
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an object value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
// Cutting back to opening : for object value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
// Was about to create an array value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an array value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
} else {
auto last_pos = str.find_last_of("[,");
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
}
// Cutting back to last [ or , for array value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\": 1" + closing)) {
// Was inside an object key string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "Cutting back to last : for object key+value\n");
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
out.json = json::parse(str);
it = temptative_end;
return true;
}
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(it, end);
it = end;
return true;
}
+37
View File
@@ -0,0 +1,37 @@
#pragma once
#include <json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
// Raw marker.
std::string marker;
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
std::string json_dump_marker;
};
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
nlohmann::ordered_json json;
common_healing_marker healing_marker;
};
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out);
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out);
+7 -8
View File
@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> patterns_at_start;
std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = trigger.value;
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
patterns_anywhere.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
+120 -44
View File
@@ -45,7 +45,7 @@ class SentencePieceTokenTypes(IntEnum):
class ModelType(IntEnum):
TEXT = 1
VISION = 2
MMPROJ = 2
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@@ -54,7 +54,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
class ModelBase:
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
ModelType.TEXT: {},
ModelType.VISION: {},
ModelType.MMPROJ: {},
}
dir_model: Path
@@ -88,7 +88,7 @@ class ModelBase:
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is VisionModel:
type(self) is MmprojModel:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
@@ -309,6 +309,7 @@ class ModelBase:
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
)
)
or not new_name.endswith(".weight")
@@ -438,7 +439,7 @@ class ModelBase:
assert names
def func(modelcls: AnyModel) -> AnyModel:
model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT
model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT
for name in names:
cls._model_classes[model_type][name] = modelcls
return modelcls
@@ -1114,60 +1115,87 @@ class TextModel(ModelBase):
self.gguf_writer.add_pooling_type(pooling_type)
class VisionModel(ModelBase):
model_type = ModelType.VISION
model_arch = gguf.MODEL_ARCH.CLIP_VISION
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
model_arch = gguf.MODEL_ARCH.MMPROJ
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
if self.has_vision_encoder and self.has_audio_encoder:
raise NotImplementedError("both vision + audio not supported yet")
# get n_embd of the text model
if "text_config" not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert self.n_embd_text > 0, "n_embd not found in hparams"
if "vision_config" not in self.hparams:
raise ValueError("vision_config not found in hparams")
# move vision config to the top level, while preserving the original hparams in global_config
self.global_config = self.hparams
self.hparams = self.hparams["vision_config"]
if "vision_config" in self.hparams:
self.hparams = self.hparams["vision_config"]
elif "audio_config" in self.hparams:
self.hparams = self.hparams["audio_config"]
else:
raise ValueError("vision_config / audio_config not found in hparams")
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
def set_gguf_parameters(self):
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
self.gguf_writer.add_vision_has_vision_encoder(True)
# vision config
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
if self.has_vision_encoder:
self.gguf_writer.add_clip_has_vision_encoder(True)
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
# vision config
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
elif self.has_audio_encoder:
self.gguf_writer.add_clip_has_audio_encoder(True)
self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
# audio config
self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_audio_block_count(self.block_count)
self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
else:
raise ValueError("MmprojModel must have either vision or audio encoder")
def write_vocab(self):
raise ValueError("VisionModel does not support vocab writing")
raise ValueError("MmprojModel does not support vocab writing")
@ModelBase.register("GPTNeoXForCausalLM")
@@ -1951,7 +1979,7 @@ class LlamaModel(TextModel):
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
)
class LlavaVisionModel(VisionModel):
class LlavaVisionModel(MmprojModel):
img_break_tok_id = -1
def __init__(self, *args, **kwargs):
@@ -1977,7 +2005,7 @@ class LlavaVisionModel(VisionModel):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
# hidden_act
@@ -2016,7 +2044,7 @@ class LlavaVisionModel(VisionModel):
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
class SmolVLMModel(VisionModel):
class SmolVLMModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "smolvlm_vision":
@@ -2028,7 +2056,7 @@ class SmolVLMModel(VisionModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
self.gguf_writer.add_vision_use_gelu(True)
@@ -2094,10 +2122,10 @@ class Llama4Model(LlamaModel):
@ModelBase.register("Llama4ForConditionalGeneration")
class Llama4VisionModel(VisionModel):
class Llama4VisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
assert self.hparams["hidden_act"] == "gelu"
@@ -2645,7 +2673,7 @@ class Qwen2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLModel(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2VL
@@ -2669,8 +2697,8 @@ class Qwen2VLModel(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLVisionModel(VisionModel):
@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["image_size"] = self.hparams.get("image_size", 560)
@@ -2685,9 +2713,9 @@ class Qwen2VLVisionModel(VisionModel):
super().set_gguf_parameters()
hparams = self.hparams
if self.global_config['model_type'] == 'qwen2_vl':
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
elif self.global_config['model_type'] == 'qwen2_5_vl':
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
self.gguf_writer.add_vision_use_silu(True)
# find n_wa_pattern (window attention pattern)
fullatt_block_indexes = hparams.get("fullatt_block_indexes")
@@ -2746,11 +2774,11 @@ class Qwen2VLVisionModel(VisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(VisionModel):
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
# hidden_act
if hparams["hidden_act"] == "silu":
@@ -4008,11 +4036,11 @@ class Gemma3Model(TextModel):
@ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(VisionModel):
class Gemma3VisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3)
# default values below are taken from HF tranformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_vision_use_gelu(True)
@@ -5959,6 +5987,52 @@ class ChameleonModel(TextModel):
return data_torch
@ModelBase.register("UltravoxModel")
class UltravoxModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA # dummy
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
@ModelBase.register("UltravoxModel")
class UltravoxAudioModel(MmprojModel):
has_vision_encoder = False # no vision encoder
has_audio_encoder = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
def tensor_force_quant(self, name, new_name, bid, n_dims):
del bid, new_name, n_dims # unused
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F16
return False
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# prevent clash naming with vision tensors
if name.startswith("multi_modal_projector"):
name = "audio." + name
if "conv1.bias" in name or "conv2.bias" in name:
# transpose conv1 and conv2 bias
data_torch = data_torch.unsqueeze(-1)
return [(self.map_tensor_name(name), data_torch)]
###### CONVERSION LOGIC ######
@@ -6134,13 +6208,15 @@ def split_str_to_n_bytes(split_str: str) -> int:
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
# maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
return arch
@@ -6203,7 +6279,7 @@ def main() -> None:
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
hparams = ModelBase.load_hparams(dir_model)
model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")
+53 -24
View File
@@ -325,36 +325,65 @@ To get the official template from original HuggingFace repos, you can use [scrip
> [!TIP]
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
> [!CAUTION]
> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.
Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
```bash
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
"model": "gpt-3.5-turbo",
"tools": [
{
"type":"function",
"function":{
"name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{
"type":"object",
"properties":{
"code":{
"type":"string",
"description":"The code to run in the ipython interpreter."
}
},
"required":["code"]
}
},
"required":["code"]
}
}
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}
],
"messages": [
{
"role": "user",
"content": "Print a hello world message with python."
}
]
}'
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"}
],
"tools": [{
"type":"function",
"function":{
"name":"get_current_weather",
"description":"Get the current weather in a given location",
"parameters":{
"type":"object",
"properties":{
"location":{
"type":"string",
"description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
}
},
"required":["location"]
}
}
}]
}'
```
+13 -1
View File
@@ -4,7 +4,9 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools
- [llama-mtmd-cli](../tools/mtmd/README.md)
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
To enable it, can use use one of the 2 methods below:
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
To enable it, you can use one of the 2 methods below:
- Use `-hf` option with a supported model (see a list of pre-quantized model below)
- To load a model using `-hf` while disabling multimodal, use `--no-mmproj`
@@ -37,6 +39,8 @@ Replaces the `(tool_name)` with the name of binary you want to use. For example,
NOTE: some models may require large context window, for example: `-c 8192`
**Vision models**:
```sh
# Gemma 3
(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF
@@ -78,3 +82,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
# Llama 4 Scout
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
```
**Audio models**:
```sh
# Ultravox 0.5
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
```
+6 -6
View File
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch_encode(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
p += s;
s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
// final batch
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch_encode(ctx, batch, out, s, n_embd);
// save embeddings to chunks
for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
batch_add_seq(query_batch, query_tokens, 0);
std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch);
+11
View File
@@ -536,6 +536,7 @@ extern "C" {
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_COUNT,
};
@@ -1024,6 +1025,16 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// GELU using erf (error function) when possible
// some backends may fallback to approximation based on Abramowitz and Stegun formula
GGML_API struct ggml_tensor * ggml_gelu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);
+133 -6
View File
@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
}
}
// GroupedMatmulV2 required tensor_list.size < 128
size_t GROUP_SIZE = 128;
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
// split and call GroupedMatmulV2
// GroupedMatmulV2 required tensor_list.size < 128
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
// split and call GroupedMatmulV2
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
@@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
return;
}
/**
* @brief Performs expert-specific matrix multiplication (MoE) with
* quantized precision using the CANN backend.
*
* This function executes a matrix multiplication operation tailored for
* Mixture of Experts (MoE) models, where the input tensor is multiplied
* with expert-specific quantized weight matrices. It leverages the CANN
* backend to perform efficient low-precision computations and stores the
* quantized result in the destination tensor `dst`.
*
* Quantization techniques reduce memory footprint and improve performance
* by using lower-bit representations (e.g., int8) instead of floating-point.
* This function is designed to work with such formats and may incorporate
* optimizations like identity-based fast paths or routing masks for sparse
* expert selection.
*
* @param ctx The context for executing CANN backend operations.
* @param dst The destination tensor where the quantized MoE multiplication result
* will be stored.
*
* @note This function assumes quantized data types and is designed for
* MoE architectures with potential sparse expert routing.
*/
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: Use aclnnGroupedMatMul
//dst [M, K, N, 1]
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
ggml_tensor * ids = dst->src[2]; //ids [K, N]
GGML_TENSOR_BINARY_OP_LOCALS
// copy index from npu to cpu
int64_t n_as = ne02; // A
int64_t n_ids = ids->ne[0]; // K
std::vector<char> ids_host(ggml_nbytes(ids));
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
ACL_MEMCPY_DEVICE_TO_HOST);
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
const enum ggml_type type = dst->src[0]->type;
float weight_elem_size;
if (type == GGML_TYPE_Q4_0) {
weight_elem_size = float(sizeof(uint8_t)) / 2;
} else if (type == GGML_TYPE_Q8_0) {
weight_elem_size = float(sizeof(uint8_t));
} else {
GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
}
// src0_row [D, M, 1, 1] weight without permute
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[0] = weight_elem_size;
src0_row.nb[1] = weight_elem_size * ne00;
src0_row.nb[2] = weight_elem_size * ne00;
src0_row.nb[3] = weight_elem_size * ne00;
size_t weight_stride = ne00 * ne01 * weight_elem_size;
size_t weight_size = weight_stride * ne02 * ne03;
// scale [D, M, 1, 1] -> scale && permute
size_t scale_elem_size = sizeof(uint16_t);
size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
// src1_row [D, 1, 1, 1] -> input
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
// dst_row [M, 1, 1, 1] -> out
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
//create weight for one row
ggml_cann_pool_alloc weight_allocator(ctx.pool());
void* weight_buffer = weight_allocator.alloc(nb02);
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
// expert index
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
// If B = 1 (broadcast), always use 0; otherwise, use id.
int64_t i11 = (ne11 == 1 ? 0 : id);
int64_t i12 = iid1;
int64_t i1 = id;
int64_t i2 = i12;
void* src0_tmp_ptr = src0_original + i02*weight_stride;
void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
// mem cpy
ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
ACL_MEMCPY_DEVICE_TO_DEVICE);
void* scale_buffer = (char*)weight_buffer + weight_stride;
ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
ACL_MEMCPY_DEVICE_TO_DEVICE);
src0_row.data = weight_buffer;
src1_row.data = src1_tmp_ptr;
dst_row.data = dst_tmp_ptr;
dst_row.src[0] = &src0_row;
dst_row.src[1] = &src1_row;
ggml_cann_mul_mat(ctx, &dst_row);
}
}
return;
}
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const enum ggml_type type = dst->src[0]->type;
switch (type) {
@@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
case GGML_TYPE_F16:
ggml_cann_mul_mat_id_fp(ctx, dst);
break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
ggml_cann_mul_mat_id_quant(ctx, dst);
break;
default:
GGML_ABORT("Unsupported type for mul_mat_id");
break;
+9
View File
@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
// Q4 && Q8 per group is not suppor on 310p device
return false;
#endif
// only support contiguous for quantized types.
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]);
default:
return false;
}
+14
View File
@@ -2202,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
{
@@ -3483,6 +3484,19 @@ void ggml_cpu_init(void) {
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
#ifdef GGML_USE_OPENMP
//if (!getenv("OMP_WAIT_POLICY")) {
// // set the wait policy to active, so that OpenMP threads don't sleep
// putenv("OMP_WAIT_POLICY=active");
//}
if (!getenv("KMP_BLOCKTIME")) {
// set the time to wait before sleeping a thread
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
putenv("KMP_BLOCKTIME=200"); // 200ms
}
#endif
}
#if defined(__ARM_ARCH)
+107
View File
@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
}
}
// ggml_compute_forward_gelu_erf
static void ggml_compute_forward_gelu_erf_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_gelu_erf_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_gelu_erf(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gelu_erf_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_gelu_erf_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_gelu_quick
static void ggml_compute_forward_gelu_quick_f32(
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_gelu(params, dst);
} break;
case GGML_UNARY_OP_GELU_ERF:
{
ggml_compute_forward_gelu_erf(params, dst);
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
ggml_compute_forward_gelu_quick(params, dst);
+16
View File
@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static const float SQRT_2_INV = 0.70710678118654752440084436210484f;
inline static float ggml_gelu_f32(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
}
}
inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
float xi = GGML_FP16_TO_FP32(x[i]);
float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
y[i] = GGML_FP32_TO_FP16(res);
}
}
#ifdef GGML_GELU_FP16
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
}
#endif
inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
float xi = x[i];
y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
}
}
inline static float ggml_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}
+1
View File
@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
+1
View File
@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
+4
View File
@@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SILU:
ggml_cuda_op_silu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_ERF:
ggml_cuda_op_gelu_erf(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
@@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
+10
View File
@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
static __device__ __forceinline__ float op_gelu_erf(float x) {
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
}
static __device__ __forceinline__ float op_gelu_quick(float x) {
const float GELU_QUICK_COEF = -1.702f;
@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu>(ctx, dst);
}
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}
+2
View File
@@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+24
View File
@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SIGMOID,
GGML_METAL_KERNEL_TYPE_GELU,
GGML_METAL_KERNEL_TYPE_GELU_4,
GGML_METAL_KERNEL_TYPE_GELU_ERF,
GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU_ERF:
{
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
if (n % 4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
n /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
int64_t n = ggml_nelements(dst);
+37
View File
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
kernel void kernel_gelu(
device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
kernel void kernel_gelu_erf(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}
kernel void kernel_silu(
device const float * src0,
device float * dst,
+316 -156
View File
@@ -27,6 +27,7 @@
#include <cmath>
#include <memory>
#include <charconv>
#include <mutex>
#undef MIN
#undef MAX
@@ -74,6 +75,7 @@ struct ggml_cl_version {
cl_uint minor = 0;
};
struct ggml_cl_compiler_version {
ADRENO_CL_COMPILER_TYPE type;
int major = -1;
@@ -91,6 +93,14 @@ struct ggml_cl_compiler_version {
}
};
static size_t align_to(size_t value, size_t to_alignment) {
GGML_ASSERT(to_alignment && "Invalid alignment (must be non-zero)");
GGML_ASSERT((to_alignment & (to_alignment - 1)) == 0 && "to_alignment must be power-of-two");
return ((value + to_alignment - 1) / to_alignment) * to_alignment;
}
// Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes.
static ggml_cl_version parse_cl_version(std::string_view str) {
size_t major_str_begin = 0;
@@ -221,13 +231,25 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
return { type, major, minor, patch };
}
struct ggml_backend_opencl_context;
// backend device context
struct ggml_backend_opencl_device_context {
cl_platform_id platform;
std::string platform_name;
cl_device_id device;
std::string device_name;
cl_device_id device;
std::string device_name;
cl_device_type device_type;
std::string device_version;
// Initialized by ggml_cl2_init().
ggml_backend_opencl_context * backend_ctx = nullptr;
// Initialized by ggml_backend_opencl_device_get_buffer_type()
ggml_backend_buffer_type buffer_type;
cl_context context = nullptr;
};
// backend context
@@ -248,6 +270,8 @@ struct ggml_backend_opencl_context {
int adreno_wave_size;
cl_bool non_uniform_workgroups;
cl_context context;
cl_command_queue queue;
@@ -344,15 +368,8 @@ struct ggml_backend_opencl_context {
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
};
static ggml_backend_device g_ggml_backend_opencl_device;
static ggml_backend_opencl_device_context g_ggml_ctx_dev_main {
/*.platform =*/ nullptr,
/*.platform_nane =*/ "",
/*.device =*/ nullptr,
/*.device_name =*/ "",
};
static int ggml_backend_opencl_n_devices = 0;
// All registered devices with a default device in the front.
static std::vector<ggml_backend_device> g_ggml_backend_opencl_devices;
// Profiling
#ifdef GGML_OPENCL_PROFILING
@@ -1107,25 +1124,19 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT("\n");
}
static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
static bool initialized = false;
static ggml_backend_opencl_context *backend_ctx = nullptr;
// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
// XXX static bool initialized = false;
// XXX static ggml_backend_opencl_context *backend_ctx = nullptr;
if (initialized) {
return backend_ctx;
}
static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev);
ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
GGML_ASSERT(dev_ctx);
GGML_ASSERT(dev_ctx->platform == nullptr);
GGML_ASSERT(dev_ctx->device == nullptr);
GGML_ASSERT(backend_ctx == nullptr);
namespace /* anonymous */ {
extern struct ggml_backend_device_i ggml_backend_opencl_device_i;
}
initialized = true;
backend_ctx = new ggml_backend_opencl_context();
backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
cl_int err;
// Look for available and suitable devices.
static std::vector<ggml_backend_device> ggml_opencl_probe_devices(ggml_backend_reg * reg) {
std::vector<ggml_backend_device> found_devices;
#ifdef GGML_OPENCL_PROFILING
GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n");
@@ -1158,11 +1169,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
struct cl_device devices[NDEV];
unsigned n_devices = 0;
struct cl_device * default_device = NULL;
unsigned default_platform_number = 0;
cl_platform_id platform_ids[NPLAT];
if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) {
GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n");
return backend_ctx;
return found_devices;
}
for (unsigned i = 0; i < n_platforms; i++) {
@@ -1197,19 +1209,22 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
}
if (default_device == NULL && p->default_device != NULL) {
default_device = p->default_device;
default_device = p->default_device;
default_platform_number = i;
}
}
if (n_devices == 0) {
GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n");
return backend_ctx;
return found_devices;
}
char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
char * user_device_string = getenv("GGML_OPENCL_DEVICE");
int user_platform_number = -1;
int user_device_number = -1;
char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
char * user_device_string = getenv("GGML_OPENCL_DEVICE");
int user_platform_number = -1;
int user_device_number = -1;
cl_device * candidate_devices = nullptr;
unsigned n_candidate_devices = 0;
unsigned n;
if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
@@ -1224,12 +1239,11 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number);
exit(1);
}
default_device = &platform->devices[user_device_number];
default_device = &platform->devices[user_device_number];
candidate_devices = platform->devices;
n_candidate_devices = platform->n_devices;
} else {
struct cl_device * selected_devices = devices;
unsigned n_selected_devices = n_devices;
// Choose a platform by matching a substring.
if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
for (unsigned i = 0; i < n_platforms; i++) {
struct cl_platform * p = &platforms[i];
@@ -1244,20 +1258,20 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
exit(1);
}
}
if (user_platform_number != -1) {
struct cl_platform * p = &platforms[user_platform_number];
selected_devices = p->devices;
n_selected_devices = p->n_devices;
default_device = p->default_device;
if (n_selected_devices == 0) {
GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
exit(1);
}
int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number;
struct cl_platform * p = &platforms[platform_idx];
candidate_devices = p->devices;
n_candidate_devices = p->n_devices;
default_device = p->default_device;
if (n_candidate_devices == 0) {
GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
exit(1);
}
if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
for (unsigned i = 0; i < n_selected_devices; i++) {
struct cl_device * d = &selected_devices[i];
for (unsigned i = 0; i < n_candidate_devices; i++) {
struct cl_device * d = &candidate_devices[i];
if (strstr(d->name, user_device_string) != NULL) {
user_device_number = d->number;
break;
@@ -1269,71 +1283,145 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
}
}
if (user_device_number != -1) {
selected_devices = &devices[user_device_number];
n_selected_devices = 1;
default_device = &selected_devices[0];
candidate_devices = &devices[user_device_number];
n_candidate_devices = 1;
default_device = &candidate_devices[0];
}
GGML_ASSERT(n_selected_devices > 0);
GGML_ASSERT(n_candidate_devices > 0);
if (default_device == NULL) {
default_device = &selected_devices[0];
default_device = &candidate_devices[0];
}
}
GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
GGML_LOG_INFO("ggml_opencl: selecting device: '%s (%s)'\n", default_device->name, default_device->version);
if (default_device->type != CL_DEVICE_TYPE_GPU) {
GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
GGML_ASSERT(n_candidate_devices != 0 && candidate_devices);
// Put the default device in front.
for (unsigned i = 1; i < n_candidate_devices; i++) {
if (&candidate_devices[i] == default_device) {
std::swap(candidate_devices[0], candidate_devices[i]);
default_device = &candidate_devices[0];
break;
}
}
dev_ctx->platform = default_device->platform->id;
dev_ctx->device = default_device->id;
backend_ctx->device = default_device->id;
GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name);
if (strstr(default_device->name, "Adreno") ||
strstr(default_device->name, "Qualcomm") ||
strstr(default_device->version, "Adreno")) {
std::vector<cl_device_id> device_ids;
for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) {
device_ids.push_back(dev->id);
}
cl_int err;
cl_context shared_context;
cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 };
CL_CHECK(
(shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err));
for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) {
GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version);
auto dev_ctx = std::unique_ptr<ggml_backend_opencl_device_context>(new ggml_backend_opencl_device_context{
/*.platform =*/dev->platform->id,
/*.platform_nane =*/dev->platform->name,
/*.device =*/dev->id,
/*.device_name =*/dev->name,
/*.device_type =*/dev->type,
/*.device_version =*/dev->version,
/*.backend_ctx =*/nullptr,
/*.buffer_type =*/{},
/*.context =*/shared_context,
});
found_devices.push_back(ggml_backend_device{
/* .iface = */ ggml_backend_opencl_device_i,
/* .reg = */ reg,
/* .context = */ dev_ctx.get(),
});
if (!ggml_cl2_init(&found_devices.back())) {
found_devices.pop_back();
GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n");
continue;
}
dev_ctx.release();
}
if (found_devices.size()) {
auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(found_devices.front().context);
GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(),
dev_ctx->device_version.c_str());
if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) {
GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n",
dev_ctx->device_name.c_str());
}
}
return found_devices;
}
// Initialize device if it is supported (returns nullptr if it is not).
static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_ASSERT(dev);
GGML_ASSERT(dev->context);
ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context;
GGML_ASSERT(dev_ctx->platform);
GGML_ASSERT(dev_ctx->device);
if (dev_ctx->backend_ctx) {
return dev_ctx->backend_ctx;
}
auto backend_ctx = std::make_unique<ggml_backend_opencl_context>();
backend_ctx->device = dev_ctx->device;
backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
if (strstr(dev_ctx->device_name.c_str(), "Adreno") ||
strstr(dev_ctx->device_name.c_str(), "Qualcomm") ||
strstr(dev_ctx->device_version.c_str(), "Adreno")) {
backend_ctx->gpu_family = GPU_FAMILY::ADRENO;
// Usually device version contains the detailed device name
backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->version);
backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str());
if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) {
backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name);
backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str());
}
// Use wave size of 64 for all Adreno GPUs.
backend_ctx->adreno_wave_size = 64;
} else if (strstr(default_device->name, "Intel")) {
} else if (strstr(dev_ctx->device_name.c_str(), "Intel")) {
backend_ctx->gpu_family = GPU_FAMILY::INTEL;
} else {
GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name);
GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str());
backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
return backend_ctx;
return nullptr;
}
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; "
"run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n");
return backend_ctx;
return nullptr;
}
#endif
// Populate backend device name
dev_ctx->platform_name = default_device->platform->name;
dev_ctx->device_name = default_device->name;
backend_ctx->device_name = default_device->name;
backend_ctx->device_name = dev_ctx->device_name;
// A local ref of cl_device_id for convenience
cl_device_id device = backend_ctx->device;
ggml_cl_version platform_version = get_opencl_platform_version(default_device->platform->id);
ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform);
// Check device OpenCL version, OpenCL 2.0 or above is required
ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device);
if (opencl_c_version.major < 2) {
GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n");
return backend_ctx;
return nullptr;
}
// Check driver version
@@ -1364,7 +1452,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
// fp16 is required
if (!backend_ctx->fp16_support) {
GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n");
return backend_ctx;
return nullptr;
}
// If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes
@@ -1373,7 +1461,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
strstr(ext_buffer, "cl_intel_subgroups") == NULL) {
GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) "
"(note that subgroups is an optional feature in OpenCL 3.0)\n");
return backend_ctx;
return nullptr;
}
cl_uint base_align_in_bits;
@@ -1397,6 +1485,15 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n",
svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
if (opencl_c_version.major >= 3) {
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
&backend_ctx->non_uniform_workgroups, 0));
} else {
GGML_ASSERT(opencl_c_version.major == 2);
// Non-uniform workgroup sizes is mandatory feature in v2.x.
backend_ctx->non_uniform_workgroups = true;
}
// Print out configurations
#ifdef GGML_OPENCL_SOA_Q
GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n");
@@ -1406,14 +1503,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n");
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
cl_context_properties properties[] = {
(intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0
};
CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
cl_int err;
// A local ref of cl_context for convenience
cl_context context = backend_ctx->context;
cl_context context = backend_ctx->context = dev_ctx->context;
//CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
// (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err :
@@ -1426,7 +1519,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err));
// Load kernels
load_cl_kernels(backend_ctx, opencl_c_version);
load_cl_kernels(backend_ctx.get(), opencl_c_version);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Allocate intermediate buffers and images
@@ -1456,10 +1549,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
// For now we support a single devices
ggml_backend_opencl_n_devices = 1;
return backend_ctx;
dev_ctx->backend_ctx = backend_ctx.release();
return dev_ctx->backend_ctx;
}
static void ggml_cl2_free(void) {
@@ -1664,10 +1755,46 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
// Syncronizes the 'backend_ctx's device with others so that commands
// enqueued to it won't start until commands in the other devices have
// completed.
static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) {
if (g_ggml_backend_opencl_devices.size() < 2)
return; // No other devices to synchronize with.
std::vector<cl_event> events;
events.reserve(g_ggml_backend_opencl_devices.size());
for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) {
auto * other_backend_ctx = ggml_cl2_init(&backend_dev);
if (backend_ctx != other_backend_ctx) {
cl_event ev;
CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev));
CL_CHECK(clFlush(other_backend_ctx->queue));
events.push_back(ev);
}
}
CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr));
for (auto ev : events) {
CL_CHECK(clReleaseEvent(ev));
}
}
static void sync_with_other_backends(ggml_backend_t backend) {
auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context);
sync_with_other_backends(backend_ctx);
}
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
// NOTE: this may oversynchronize by synchronizing with
// backends/devices which don't compute 'cgraph's
// dependencies.
sync_with_other_backends(backend);
if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
@@ -2058,15 +2185,16 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
// The original tensor memory is divided into scales and quants, i.e.,
// we first store scales, then quants.
// Create subbuffer for scales.
region.origin = extra_orig->offset + tensor->view_offs + offset;
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Create subbuffer for quants.
region.origin = extra_orig->offset + tensor->view_offs + offset + size_d;
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
region.size = size_q;
extra->q = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
@@ -2271,8 +2399,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
cl_context context = backend_ctx->context;
cl_command_queue queue = backend_ctx->queue;
// Make sure all previously submitted commands are finished.
CL_CHECK(clFinish(queue));
// Make sure all previously submitted commands in other devices are finished.
sync_with_other_backends(backend_ctx);
#ifdef GGML_OPENCL_SOA_Q
// In end-to-end runs, get_tensor is usually used to get back the logits,
@@ -2376,13 +2504,8 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b
}
static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
// FIXME: not thread safe, device may not be initialized yet
static cl_uint alignment = -1;
if (alignment == (cl_uint)-1) {
ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
alignment = backend_ctx->alignment;
}
return alignment;
ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
return backend_ctx->alignment;
}
static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
@@ -2409,16 +2532,6 @@ static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
/* .is_host = */ NULL,
};
ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
static ggml_backend_buffer_type buffer_type = {
/* .iface = */ ggml_backend_opencl_buffer_type_interface,
/* .device = */ &g_ggml_backend_opencl_device,
/* .context = */ nullptr,
};
return &buffer_type;
}
//
// backend device
//
@@ -2476,9 +2589,15 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co
}
static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) {
return ggml_backend_opencl_buffer_type();
auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(dev->context);
GGML_UNUSED(dev);
dev_ctx->buffer_type = ggml_backend_buffer_type{
/* .iface = */ ggml_backend_opencl_buffer_type_interface,
/* .device = */ dev,
/* .context = */ nullptr,
};
return &dev_ctx->buffer_type;
}
static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
@@ -2494,12 +2613,21 @@ static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const
}
static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name;
// Check 'dev' and 'buffer_type' are not objects belonging to this backend.
if (dev->iface.get_name != ggml_backend_opencl_device_get_name ||
buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) {
return false;
}
GGML_UNUSED(dev);
// Check cl_context is the same. clEnqueue* commands may not use
// buffers from another cl_context.
ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev);
ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device);
return backend_ctx0->context == backend_ctx1->context;
}
static struct ggml_backend_device_i ggml_backend_opencl_device_i = {
namespace /* anonymous */ {
struct ggml_backend_device_i ggml_backend_opencl_device_i = {
/* .get_name = */ ggml_backend_opencl_device_get_name,
/* .get_description = */ ggml_backend_opencl_device_get_description,
/* .get_memory = */ ggml_backend_opencl_device_get_memory,
@@ -2516,6 +2644,7 @@ static struct ggml_backend_device_i ggml_backend_opencl_device_i = {
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};
}
// Backend registry
@@ -2526,15 +2655,15 @@ static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) {
}
static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) {
return ggml_backend_opencl_n_devices;
return g_ggml_backend_opencl_devices.size();
GGML_UNUSED(reg);
}
static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg));
return &g_ggml_backend_opencl_device;
return &g_ggml_backend_opencl_devices[index];
GGML_UNUSED(reg);
GGML_UNUSED(index);
@@ -2548,27 +2677,23 @@ static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = {
};
ggml_backend_reg_t ggml_backend_opencl_reg(void) {
// TODO: make this thread-safe somehow?
static std::mutex mutex;
static ggml_backend_reg reg;
static bool initialized = false;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
reg = ggml_backend_reg {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_opencl_reg_i,
/* .context = */ NULL,
};
g_ggml_backend_opencl_device = ggml_backend_device {
/* .iface = */ ggml_backend_opencl_device_i,
/* .reg = */ &reg,
/* .context = */ &g_ggml_ctx_dev_main,
};
ggml_cl2_init(&g_ggml_backend_opencl_device);
initialized = true;
if (initialized) {
return &reg;
}
initialized = true;
g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(&reg);
reg = ggml_backend_reg{
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_opencl_reg_i,
/* .context = */ NULL,
};
return &reg;
}
@@ -2942,14 +3067,19 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
} else {
unsigned int nth = MIN(64, ne0);
@@ -3077,14 +3207,19 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
} else {
unsigned int nth = MIN(64, ne0);
@@ -3233,14 +3368,19 @@ static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -3273,14 +3413,19 @@ static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -3320,14 +3465,19 @@ static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, cons
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -4230,14 +4380,19 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
@@ -4418,14 +4573,19 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (ne00 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
}
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
#endif
}
}
+222 -111
View File
@@ -1,74 +1,93 @@
#include "binbcast.hpp"
#include <array>
#include <cstddef>
#include <cstdint>
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "ggml.h"
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
auto element_id = it.get_global_id(0);
auto global_range = it.get_global_range(0);
for (; element_id < num_elements; element_id += global_range) {
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
dst[element_id] = val_to_store;
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0)) /
ne3;
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0)) %
ne3;
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0;
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
}
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
int s11, int s12, int s13, std::size_t num_dst_elements,
const sycl::nd_item<1> & item_ct1) {
auto calculate_logical_index =
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
std::array<int, 4> logical_index;
#pragma unroll(4)
for (int i = 3; i >= 0; i--) {
logical_index[i] = element_id % dims[i];
element_id /= dims[i];
}
return logical_index;
};
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
const std::array<int, 4> & indices) __attribute__((always_inline))
->std::size_t {
std::size_t index = 0;
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
auto index_i = indices[i];
if (indices[i] >= dims[i]) {
index_i = indices[i] % dims[i];
}
index += strides[i] * index_i;
}
return index;
};
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
auto element_id = item_ct1.get_global_id(0);
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
dst[dst_index] = val_to_store;
const int i3 = i/(ne2*ne1*ne0);
const int i2 = (i/(ne1*ne0)) % ne2;
const int i1 = (i/ne0) % ne1;
const int i0 = i % ne0;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
template<float (*bin_op)(const float, const float)>
struct bin_bcast_sycl {
template <typename src0_t, typename src1_t, typename dst_t>
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -77,73 +96,165 @@ template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
const std::array<int64_t, 4> & dst_dims) -> bool {
for (int i = 0; i < 4; i++) {
if (dst_dims[i] > src_dims[i]) {
return true;
}
}
return false;
int nr0 = ne10 / ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int nr3 = ne13/ne3;
int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension
int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
};
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
// dst strides in number of elements
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
// src1 strides in number of elements
size_t s10 = nb10 / sizeof(src0_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
// src0 strides in number of elements
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
std::size_t local_range = 256;
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
if (! needs_broadcasting && all_contiguous) {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
});
});
} else {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
});
});
GGML_UNUSED(s00);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL);
sycl::range<3> block_dims(1, 1, 1);
block_dims[2] = std::min<unsigned int>(hne0, block_size);
block_dims[1] = std::min<unsigned int>(
ne1, block_size / (unsigned int)block_dims[2]);
block_dims[0] = std::min(
std::min<unsigned int>(
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
(unsigned int)block_dims[1]),
64U);
sycl::range<3> block_nums(
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
(ne1 + block_dims[1] - 1) / block_dims[1],
(hne0 + block_dims[2] - 1) / block_dims[2]);
if (block_nums[0] > 65535) {
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
sycl::range<3>(1, 1, block_size),
sycl::range<3>(1, 1, block_size)),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
s03, s11, s12, s13, item_ct1);
});
}
} else {
/*
DPCT1049:16: The work-group size passed to the SYCL kernel may
exceed the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if
needed.
*/
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
item_ct1);
});
}
}
}
};
+35 -3
View File
@@ -3740,7 +3740,7 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
data, (const char *)tensor->data + offset, size).wait()));
data, (const char *)tensor->data + offset, size)));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -3760,7 +3760,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
*/
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
dst->data, src->data, ggml_nbytes(dst)).wait()));
dst->data, src->data, ggml_nbytes(dst))));
return true;
}
@@ -3809,11 +3809,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
}
}
#ifdef GGML_SYCL_GRAPH
static bool check_graph_compatibility(ggml_cgraph * cgraph) {
if (ggml_sycl_info().device_count > 1) {
// A sycl_ex::command_graph object can only be created for a single device
GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
const ggml_op node_op = cgraph->nodes[i]->op;
switch (node_op) {
default:
break;
case GGML_OP_CONCAT:
// ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
// but wait() can't be called on the events returned by a queue recording
// to a graph.
[[fallthrough]];
case GGML_OP_MUL_MAT_ID:
// ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
// submitting a memcpy operation, but wait() can't be called on a queue that
// is recording to a graph.
GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
ggml_op_name(node_op));
return false;
}
}
return true;
}
#endif
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
#ifdef GGML_SYCL_GRAPH
if (!g_ggml_sycl_disable_graph) {
bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
if (use_sycl_graph) {
const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
if (!graph_support) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
+51 -4
View File
@@ -2804,23 +2804,29 @@ static vk_device ggml_vk_get_device(size_t idx) {
pipeline_robustness = true;
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
device->subgroup_size_control = true;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT")) {
device->coopmat_support = true;
device->coopmat_m = 0;
device->coopmat_n = 0;
device->coopmat_k = 0;
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true;
#endif
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
device->integer_dot_product = true;
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
}
}
@@ -4670,6 +4676,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
}
}
if (src->type == to) {
// Copy two or four bytes at a time, depending on block size.
// For quantized types, we scale by block size/type size. But
// this path is also used for bf16->bf16 for example, where the
// type size must be exactly 2 or 4.
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
if ((ggml_type_size(src->type) % 4) == 0) {
return ctx->device->pipeline_contig_cpy_f32_f32;
} else {
return ctx->device->pipeline_contig_cpy_f16_f16;
}
}
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
GGML_ABORT("fatal error");
}
@@ -6731,7 +6750,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_UNARY:
case GGML_OP_CONV_2D_DW:
{
const uint32_t ne = ggml_nelements(dst);
uint32_t ne = ggml_nelements(dst);
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// Convert from number of logical elements to 2- or 4-byte units.
ne /= ggml_blck_size(src0->type);
if ((ggml_type_size(src0->type) % 4) == 0) {
ne *= ggml_type_size(src0->type) / 4;
} else {
ne *= ggml_type_size(src0->type) / 2;
}
}
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
@@ -7281,8 +7309,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
uint32_t ne = (uint32_t)ggml_nelements(src0);
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
// Convert from number of logical elements to 2- or 4-byte units.
ne /= ggml_blck_size(src0->type);
if ((ggml_type_size(src0->type) % 4) == 0) {
ne *= ggml_type_size(src0->type) / 4;
} else {
ne *= ggml_type_size(src0->type) / 2;
}
}
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
(uint32_t)ggml_nelements(src0),
ne,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
@@ -9264,8 +9303,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
try {
ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
} catch (vk::SystemError& e) {
std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
// fallback to cpu buffer
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
}
@@ -9867,6 +9905,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
// We can handle copying from a type to the same type if it's
// contiguous (memcpy). We use f16 or f32 shaders to do the copy,
// so the type/block size must be a multiple of 4.
if (src0_type == src1_type &&
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
(ggml_type_size(src0_type) % 2) == 0) {
return true;
}
return false;
} break;
case GGML_OP_REPEAT:
+16 -1
View File
@@ -1099,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSWISH",
"HARDSIGMOID",
"EXP",
"GELU_ERF",
};
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2501,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
}
// ggml_gelu_erf
struct ggml_tensor * ggml_gelu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
}
struct ggml_tensor * ggml_gelu_erf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
}
// ggml_gelu_quick
struct ggml_tensor * ggml_gelu_quick(
+76 -7
View File
@@ -219,10 +219,13 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class ClipVision:
class Clip:
PROJECTOR_TYPE = "clip.projector_type"
HAS_VISION_ENCODER = "clip.has_vision_encoder"
HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
class ClipVision:
IMAGE_SIZE = "clip.vision.image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@@ -243,19 +246,33 @@ class Keys:
class Projector:
SCALE_FACTOR = "clip.vision.projector.scale_factor"
class ClipAudio:
NUM_MEL_BINS = "clip.audio.num_mel_bins"
EMBEDDING_LENGTH = "clip.audio.embedding_length"
FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
PROJECTION_DIM = "clip.audio.projection_dim"
BLOCK_COUNT = "clip.audio.block_count"
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
class Projector:
STACK_FACTOR = "clip.audio.projector.stack_factor"
#
# recommended mapping of model tensor names for storage in gguf
#
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
CLIP_VISION = "clip-vision"
MODEL = "model"
ADAPTER = "adapter"
MMPROJ = "mmproj" # dummy, unused for now
class MODEL_ARCH(IntEnum):
CLIP_VISION = auto() # dummy arch for clip.cpp
MMPROJ = auto() # dummy arch for clip.cpp
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
@@ -514,10 +531,27 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
V_MM_PATCH_MERGER = auto() # mistral small 3.1
# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_CONV1D = auto()
A_PRE_NORM = auto()
A_POST_NORM = auto()
A_ENC_ATTN_Q = auto()
A_ENC_ATTN_K = auto()
A_ENC_ATTN_V = auto()
A_ENC_INPUT_NORM = auto()
A_ENC_OUTPUT = auto()
A_ENC_OUTPUT_NORM = auto()
A_ENC_FFN_UP = auto()
A_ENC_FFN_GATE = auto()
A_ENC_FFN_DOWN = auto()
A_MMPROJ = auto()
A_MM_NORM_PRE = auto()
A_MM_NORM_MID = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
@@ -776,10 +810,27 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.CLIP_VISION: [
MODEL_ARCH.MMPROJ: [
MODEL_TENSOR.V_MMPROJ,
MODEL_TENSOR.V_MMPROJ_FC,
MODEL_TENSOR.V_MMPROJ_MLP,
@@ -819,6 +870,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
MODEL_TENSOR.V_MM_PATCH_MERGER,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM,
MODEL_TENSOR.A_ENC_ATTN_Q,
MODEL_TENSOR.A_ENC_ATTN_K,
MODEL_TENSOR.A_ENC_ATTN_V,
MODEL_TENSOR.A_ENC_INPUT_NORM,
MODEL_TENSOR.A_ENC_OUTPUT,
MODEL_TENSOR.A_ENC_OUTPUT_NORM,
MODEL_TENSOR.A_ENC_FFN_UP,
MODEL_TENSOR.A_ENC_FFN_GATE,
MODEL_TENSOR.A_ENC_FFN_DOWN,
MODEL_TENSOR.A_MMPROJ,
MODEL_TENSOR.A_MM_NORM_PRE,
MODEL_TENSOR.A_MM_NORM_MID,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -2186,6 +2254,7 @@ class VisionProjectorType:
LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
+1 -1
View File
@@ -251,7 +251,7 @@ class GGUFReader:
offs += curr_size
return offs - orig_offs, aparts, data_idxs, types
# We can't deal with this one.
raise ValueError('Unknown/unhandled field type {gtype}')
raise ValueError(f'Unknown/unhandled field type {gtype}')
def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
offs = orig_offs
+36 -7
View File
@@ -896,7 +896,7 @@ class GGUFWriter:
def add_remove_extra_whitespaces(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
def add_precompiled_charsmap(self, charsmap: bytes) -> None:
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
@@ -936,12 +936,18 @@ class GGUFWriter:
# for vision models
def add_clip_has_vision_encoder(self, value: bool) -> None:
self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
def add_clip_has_audio_encoder(self, value: bool) -> None:
self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
def add_clip_projector_type(self, value: str) -> None:
self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
def add_vision_projection_dim(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
def add_vision_has_vision_encoder(self, value: bool) -> None:
self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
def add_vision_patch_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
@@ -957,9 +963,6 @@ class GGUFWriter:
def add_vision_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
def add_vision_projector_type(self, value: str) -> None:
self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
def add_vision_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
@@ -987,6 +990,32 @@ class GGUFWriter:
def add_vision_n_wa_pattern(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
# audio models
def add_audio_projection_dim(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
def add_audio_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
def add_audio_feed_forward_length(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
def add_audio_block_count(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
def add_audio_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
def add_audio_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
def add_audio_num_mel_bins(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
def add_audio_stack_factor(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
+62
View File
@@ -1110,6 +1110,68 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_PATCH_MERGER: (
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
),
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: (
"audio_tower.embed_positions", # ultravox
),
MODEL_TENSOR.A_ENC_CONV1D: (
"audio_tower.conv{bid}", # ultravox
),
MODEL_TENSOR.A_PRE_NORM: (),
MODEL_TENSOR.A_POST_NORM: (
"audio_tower.layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_Q: (
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_K: (
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_V: (
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
),
MODEL_TENSOR.A_ENC_INPUT_NORM: (
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_OUTPUT: (
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
),
MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
"audio_tower.layers.{bid}.final_layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_FFN_UP: (
"audio_tower.layers.{bid}.fc1", # ultravox
),
MODEL_TENSOR.A_ENC_FFN_GATE: (),
MODEL_TENSOR.A_ENC_FFN_DOWN: (
"audio_tower.layers.{bid}.fc2", # ultravox
),
MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox
),
MODEL_TENSOR.A_MM_NORM_PRE: (
"audio.multi_modal_projector.ln_pre", # ultravox
),
MODEL_TENSOR.A_MM_NORM_MID: (
"audio.multi_modal_projector.ln_mid", # ultravox
),
}
# architecture-specific block mappings
+62
View File
@@ -0,0 +1,62 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- '' }}
{%- endif %}
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" and not message.tool_calls %}
{%- set content = message.content %}
{%- if not loop.last %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set content = message.content %}
{%- if not loop.last %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\n' + content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}
+1
View File
@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
```
@@ -1,3 +1,7 @@
-r ./requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.2.1
torch~=2.2.1; platform_machine != "s390x"
# torch s390x packages can only be found from nightly builds
--extra-index-url https://download.pytorch.org/whl/nightly
torch>=0.0.0.dev0; platform_machine == "s390x"
@@ -1,3 +1,7 @@
-r ./requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.2.1
torch~=2.2.1; platform_machine != "s390x"
# torch s390x packages can only be found from nightly builds
--extra-index-url https://download.pytorch.org/whl/nightly
torch>=0.0.0.dev0; platform_machine == "s390x"
@@ -1,2 +1,4 @@
-r ./requirements-convert_hf_to_gguf.txt
--extra-index-url https://download.pytorch.org/whl/cpu
# torch s390x packages can only be found from nightly builds
--extra-index-url https://download.pytorch.org/whl/nightly
+11
View File
@@ -12,6 +12,7 @@
export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
@@ -205,6 +206,7 @@ def run(
model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
@@ -229,6 +231,12 @@ def run(
# n_ctx = 8192
n_ctx = 2048
if model is None:
if hf is not None:
model = hf.split("/")[-1]
elif ollama is not None:
model = ollama
assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
with output.open('a' if append else 'w') as output_file:
@@ -320,6 +328,7 @@ def run(
server.model_hf_repo = hf
server.model_hf_file = None
server.chat_template = chat_template
server.chat_template_file = chat_template_file
server.server_path = server_path
if port is not None:
server.server_port = port
@@ -335,6 +344,7 @@ def run(
temp=t,
output_kwargs=dict(
chat_template=chat_template,
chat_template_file=chat_template_file,
),
request_kwargs=dict(
ignore_chat_grammar=ignore_chat_grammar,
@@ -355,6 +365,7 @@ def run(
temp=t,
output_kwargs=dict(
chat_template=None,
chat_template_file=None,
),
request_kwargs=dict(
model=ollama,
+12 -2
View File
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
grammar.awaiting_trigger = false;
// get from the first match to the end of the string
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
auto constrained_str = grammar.trigger_buffer.substr(start);
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
+4 -4
View File
@@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
@@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
+17 -1
View File
@@ -2,6 +2,22 @@
#include "ggml.h"
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
for (uint32_t il = 0; il < n_layer; ++il) {
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
}
}
bool llama_hparams::is_swa_any() const {
for (uint32_t il = 0; il < n_layer; ++il) {
if (swa_layers[il]) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_head(uint32_t il) const {
if (il < n_layer) {
return n_head_arr[il];
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
return swa_layers[il];
}
GGML_ABORT("fatal error");
+23 -3
View File
@@ -102,9 +102,12 @@ struct llama_hparams {
// Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0;
// if swa_layers[il] == true, then layer il is SWA
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
// by default, all layers are dense
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
// for State Space Models
uint32_t ssm_d_conv = 0;
@@ -142,6 +145,23 @@ struct llama_hparams {
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
// note that if n_pattern == 0, all layers are SWA
// if n_pattern == 1, all layers are dense
// example: n_pattern = 3
// il == 0: swa
// il == 1: swa
// il == 2: dense
// il == 3: swa
// il == 4: swa
// il == 5: dense
// il == 6: swa
// etc ...
void set_swa_pattern(uint32_t n_pattern);
// return true if one of the layers is SWA
bool is_swa_any() const;
uint32_t n_head(uint32_t il = 0) const;
uint32_t n_head_kv(uint32_t il = 0) const;
+18 -10
View File
@@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
GGML_ASSERT(hparams.n_expert_used == 0);
}
// zero-out the array hparams
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
@@ -574,7 +577,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
switch (hparams.n_expert) {
case 16: type = LLM_TYPE_17B_16E; break;
@@ -863,7 +866,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_swa = 0;
hparams.n_swa_pattern = 1;
hparams.set_swa_pattern(1);
}
} break;
case LLM_ARCH_PHIMOE:
@@ -935,7 +938,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096; // default value of gemma 2
hparams.n_swa_pattern = 2;
hparams.set_swa_pattern(2);
hparams.attn_soft_cap = true;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -953,7 +956,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_GEMMA3:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 6;
hparams.set_swa_pattern(6);
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1038,7 +1041,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_COHERE2:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 4;
hparams.set_swa_pattern(4);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@@ -2486,7 +2489,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
@@ -4320,7 +4327,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -13189,6 +13196,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_WAVTOKENIZER_DEC:
{
res = nullptr;
} break;
@@ -13215,7 +13223,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.n_swa_pattern != 1);
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_unified_iswa(
*this,
@@ -13229,7 +13237,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_batch,
padding);
} else {
GGML_ASSERT(hparams.n_swa_pattern == 1);
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache_unified(
*this,
+4 -4
View File
@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
}
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
// at the beginning tokenization score is zero
tokenization_results[0] = { vocab.token_unk(), 0, 0 };
@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
const double challenger_score = current_best.score_sum + token_score;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
struct best_tokenization challenger = { token_id, input_offset, challenger_score };
current_champ = challenger;
}
}
@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
current_champ = challenger;
}
}
@@ -1007,7 +1007,7 @@ private:
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
double score_sum;
};
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
+3 -1
View File
@@ -142,8 +142,10 @@ if (NOT WIN32)
# llama_build_and_test(test-double-float.cpp) # SLOW
endif()
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-regex-partial.cpp)
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
+355
View File
@@ -0,0 +1,355 @@
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <exception>
#include <iostream>
#include <json.hpp>
#include <string>
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
using json = nlohmann::ordered_json;
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void assert_equals(const char * expected, const std::string & actual) {
return assert_equals<std::string>(expected, actual);
}
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
try {
fn();
} catch (const std::exception & e) {
if (expected_exception_pattern.empty()) {
return;
}
std::regex expected_exception_regex(expected_exception_pattern);
std::string actual_message = e.what();
if (std::regex_search(actual_message, expected_exception_regex)) {
return;
}
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
}
throw std::runtime_error("Exception was expected but not thrown");
}
static void test_reasoning() {
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ true,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<think>Cogito</think>", builder.result().content);
assert_equals("Ergo sum", builder.consume_rest());
}
}
static void test_regex() {
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
};
test_throws("Hello, world!", "abc", "^abc$");
test_throws("Hello, world!", "e", "^e$");
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
builder.consume_regex(common_regex("Hello"));
assert_equals(", world!", builder.consume_rest());
}
{
// When in non partial mode, we can say whether the regex was consumed or not.
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
}
{
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
assert_equals(true, res.has_value());
// Verify captures
assert_equals<size_t>(2, res->groups.size());
assert_equals("Hell", builder.str(res->groups[0]));
assert_equals("el", builder.str(res->groups[1]));
// Verify position is after the match
assert_equals<size_t>(4, builder.pos());
assert_equals("o,", builder.consume_rest());
}
{
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
assert_throws([&]() {
builder.try_consume_regex(common_regex("Hello, world!"));
}, "^Hello, world!$");
}
// Now regardless of the mode, we can tell these aren't a match.
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
}
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_literal("Oh"));
}
}
const std::vector<std::string> barely_healable_jsons = {
"{",
"{\"",
"{\"\\",
"{\"n",
"{\"name\"",
"{\"name\":",
"{\"name\":\"",
"{\"name\":\"\\",
"{\"name\":\"python",
"{\"name\":\"python\\",
"{\",",
"{\":",
"{\"[",
"{\"]",
"{\"{",
"{\"}",
"{\"1",
"{\"name\":\",",
"{\"name\":\":",
"{\"name\":\"[",
"{\"name\":\"]",
"{\"name\":\"{",
"{\"name\":\"}",
"{\"name\":\"1",
};
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
common_chat_msg_parser builder(input, is_partial, {});
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
}
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
common_chat_msg_parser builder(input, parse_as_partial, {});
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, js->value.dump());
}
static void test_json_with_dumped_args_no_args() {
// Normal JSON, nothing to heal, nothing to dump
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
// Full json is args
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
// If the arguments are further down, don't heal partial content.
for (const auto & src : barely_healable_jsons) {
test(src, true, {{"arguments"}}, {}, "{}");
}
// But heal content that isn't partial.
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
}
static void test_json_with_dumped_args() {
// Partial content.
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
test("{\"content\": ", true, {}, {{"content"}}, "{}");
// If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
for (const auto & src : barely_healable_jsons) {
test(src, true, {{}}, {}, src);
}
// Full JSON w/ args
for (auto parse_as_partial : {true, false}) {
test_with_args(
R"({"name": "python", "args": {"arg1": 1}})",
R"({"name":"python","args":"{\"arg1\":1}"})",
parse_as_partial,
/* is_partial= */ false
);
}
// Partial JSON w/ partial args
test_with_args(
R"({"foo": "bar", "args": {")",
R"({"foo":"bar","args":"{\""})"
);
// Partial args broken in object key
test_with_args(
R"({"foo": "bar", "args": {"ar)",
R"({"foo":"bar","args":"{\"ar"})"
);
// Partial args broken after object key
test_with_args(
R"({"foo": "bar", "args": {"arg1")",
R"({"foo":"bar","args":"{\"arg1\""})"
);
// Partial args broken before object value
test_with_args(
R"({"foo": "bar", "args": {"arg1":)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken before object value (space)
test_with_args(
R"({"foo": "bar", "args": {"arg1": )",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that may not be complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1 )",
R"({"foo":"bar","args":"{\"arg1\":1"})"
);
// Partial args broken in object value that is incomplete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": ")",
R"({"foo":"bar","args":"{\"arg1\":\""})"
);
// Partial args broken in object value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": "1")",
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
);
// Partial args broken on array opening
test_with_args(
R"({"foo": "bar", "args": [)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is incomplete (int)
test_with_args(
R"({"foo": "bar", "args": [1)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": [1 )",
R"({"foo":"bar","args":"[1"})"
);
// Partial args broken on array value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": ["1")",
R"({"foo":"bar","args":"[\"1\""})"
);
// Partial args broken after array value
test_with_args(
R"({"foo": "bar", "args": [1,)",
R"({"foo":"bar","args":"[1,"})"
);
// Partial args broken on nested array
test_with_args(
R"({"foo": "bar", "args": {"arg1": [)",
R"({"foo":"bar","args":"{\"arg1\":["})"
);
}
static void test_positions() {
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_to(100); });
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_back(1); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(8);
assert_equals<size_t>(8, builder.pos());
builder.move_back(1);
assert_equals<size_t>(7, builder.pos());
assert_equals("world!", builder.consume_rest());
builder.move_to(0);
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.finish(); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(builder.input().size());
builder.finish();
}
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
builder.move_to(builder.input().size());
assert_equals<size_t>(builder.input().size(), builder.pos());
builder.finish();
}
}
int main() {
test_positions();
test_json_with_dumped_args_no_args();
test_json_with_dumped_args();
test_reasoning();
test_regex();
std::cout << "All tests passed!\n";
return 0;
}
+729 -276
View File
File diff suppressed because it is too large Load Diff
+237
View File
@@ -0,0 +1,237 @@
#include "common.h"
#include "json-partial.h"
#include <exception>
#include <iostream>
#include <stdexcept>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void test_json_healing() {
auto parse = [](const std::string & str) {
std::cerr << "# Parsing: " << str << '\n';
std::string::const_iterator it = str.begin();
const auto end = str.end();
common_json out;
std::string healing_marker = "$llama.cpp.json$";
if (common_json_parse(it, end, healing_marker, out)) {
auto dump = out.json.dump();
std::cerr << "Parsed: " << dump << '\n';
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
std::string result;
if (!out.healing_marker.json_dump_marker.empty()) {
auto i = dump.find(out.healing_marker.json_dump_marker);
if (i == std::string::npos) {
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
}
result = dump.substr(0, i);
} else {
result = dump;
}
std::cerr << "Result: " << result << '\n';
if (string_starts_with(str, result)) {
std::cerr << "Failure!\n";
}
// return dump;
} else {
throw std::runtime_error("Failed to parse: " + str);
}
};
auto parse_all = [&](const std::string & str) {
for (size_t i = 1; i < str.size(); i++) {
parse(str.substr(0, i));
}
};
parse_all("{\"a\": \"b\"}");
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
parse_all("[{\"a\": \"b\"}]");
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
assert_equals<std::string>(expected, out.json.dump());
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
// No healing needed:
test(
{
R"([{"a":"b"}, "y"])",
},
R"([{"a":"b"},"y"])",
""
);
// Partial literals can't be healed:
test(
{
R"([1)",
R"([tru)",
R"([n)",
R"([nul)",
R"([23.2)",
},
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"({"a": 1)",
R"({"a": tru)",
R"({"a": n)",
R"({"a": nul)",
R"({"a": 23.2)",
},
R"({"a":"$foo"})",
R"("$foo)"
);
test(
{
R"({)",
},
R"({"$foo":1})",
R"("$foo)"
);
test(
{
R"([)",
},
R"(["$foo"])",
R"("$foo)"
);
// Healing right after a full literal
test(
{
R"(1 )",
},
R"(1)",
""
);
test(
{
R"(true)",
R"(true )",
},
R"(true)",
""
);
test(
{
R"(null)",
R"(null )",
},
R"(null)",
""
);
test(
{
R"([1 )",
},
R"([1,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{})",
R"([{} )",
},
R"([{},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true)",
},
// TODO: detect the true/false/null literal was complete
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"([true )",
},
R"([true,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true,)",
},
R"([true,"$foo"])",
R"("$foo)"
);
// Test nesting
test(
{
R"([{"a": [{"b": [{)",
},
R"([{"a":[{"b":[{"$foo":1}]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": [{"b": [)",
},
R"([{"a":[{"b":["$foo"]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": "b"})",
R"([{"a": "b"} )",
},
R"([{"a":"b"},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{"a": "b"},)",
R"([{"a": "b"}, )",
},
R"([{"a":"b"},"$foo"])",
R"("$foo)"
);
test(
{
R"({ "code)",
},
R"({"code$foo":1})",
R"($foo)"
);
test(
{
R"({ "code\)",
},
R"({"code\\$foo":1})",
R"(\$foo)"
);
test(
{
R"({ "code")",
},
R"({"code":"$foo"})",
R"(:"$foo)"
);
test(
{
R"({ "key")",
},
R"({"key":"$foo"})",
R"(:"$foo)"
);
}
int main() {
test_json_healing();
std::cerr << "All tests passed.\n";
return 0;
}
+13 -2
View File
@@ -1,5 +1,15 @@
# mtmd
# compile mtmd-audio separately to avoid long compile times with miniaudio.h
# TODO @ngxson : move miniaudio.h and stb_image.h to mtmd-helper.cpp, then compile the helper as a separate library
add_library(mtmd_audio STATIC mtmd-audio.cpp mtmd-audio.h)
if (BUILD_SHARED_LIBS)
set_target_properties(mtmd_audio PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
target_link_libraries(mtmd_audio PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(mtmd_audio PRIVATE cxx_std_17)
target_include_directories(mtmd_audio PRIVATE .)
add_library(mtmd OBJECT
mtmd.cpp
mtmd-helper.cpp
@@ -9,7 +19,7 @@ add_library(mtmd OBJECT
clip-impl.h
)
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(mtmd PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(mtmd PUBLIC .)
target_include_directories(mtmd PRIVATE ../..)
@@ -22,12 +32,13 @@ if (BUILD_SHARED_LIBS)
set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(mtmd_shared PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS mtmd_shared LIBRARY)
endif()
if (NOT MSVC)
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
target_compile_options(mtmd_audio PRIVATE -Wno-cast-qual) # miniaudio.h
endif()
if(TARGET BUILD_INFO)
+37 -12
View File
@@ -16,22 +16,26 @@
#define KEY_FTYPE "general.file_type"
#define KEY_NAME "general.name"
#define KEY_DESCRIPTION "general.description"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder"
#define KEY_HAS_VISION_ENC "clip.has_vision_encoder"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.vision.embedding_length"
#define KEY_N_FF "clip.vision.feed_forward_length"
#define KEY_N_BLOCK "clip.vision.block_count"
#define KEY_N_HEAD "clip.vision.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM "clip.vision.projection_dim"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
#define KEY_PROJ_DIM "clip.%s.projection_dim"
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
// vision-specific
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
@@ -39,13 +43,18 @@
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
// audio-specific
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
//
// tensor name constants
//
#define TN_POS_EMBD "v.position_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
@@ -95,6 +104,12 @@
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
// ultravox
#define TN_CONV1D "a.conv1d.%d.%s"
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
#define TN_MM_NORM_MID "mm.a.norm_mid.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@@ -110,6 +125,7 @@ enum projector_type {
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_ULTRAVOX,
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_UNKNOWN,
@@ -126,6 +142,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
};
@@ -147,8 +164,10 @@ struct clip_image_u8 {
std::vector<uint8_t> buf;
};
// RGB float32 image (NHWC)
// Memory layout: RGBRGBRGB...
// For images, buf.size() == nx*ny*3
// Memory layout: RGBRGBRGB...
// For audio, only one channel is used, buf.size() == nx*ny
// nx will be n_frames and ny will be n_mel
struct clip_image_f32 {
int nx;
int ny;
@@ -242,6 +261,7 @@ struct clip_image_u8_batch {
struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
bool is_audio = false;
// for llava-uhd style models, we need to know the grid size
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
@@ -249,7 +269,12 @@ struct clip_image_f32_batch {
int grid_y = 0;
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch;
clip_image_f32_batch new_batch{
/* entries */ {},
/* is_audio */ is_audio,
/* grid_x */ grid_x,
/* grid_y */ grid_y,
};
new_batch.entries.reserve(entries.size());
for (const auto & entry : entries) {
new_batch.entries.emplace_back(new clip_image_f32(*entry));
+260 -56
View File
@@ -35,6 +35,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
enum ffn_op_type {
FFN_GELU,
FFN_GELU_ERF,
FFN_SILU,
FFN_GELU_QUICK,
};
@@ -165,6 +166,9 @@ enum patch_merge_type {
};
struct clip_hparams {
bool has_vision = false;
bool has_audio = false;
int32_t image_size;
int32_t patch_size;
int32_t n_embd;
@@ -191,6 +195,10 @@ struct clip_hparams {
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
};
struct clip_layer {
@@ -332,6 +340,14 @@ struct clip_vision_model {
// pixtral
ggml_tensor * token_embd_img_break = nullptr;
ggml_tensor * mm_patch_merger_w = nullptr;
// ultravox / whisper encoder
ggml_tensor * conv1d_1_w = nullptr;
ggml_tensor * conv1d_1_b = nullptr;
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
};
struct clip_ctx {
@@ -1408,6 +1424,104 @@ struct clip_graph {
return gf;
}
// whisper encoder with custom projector
ggml_cgraph * build_whisper_enc() {
const int n_frames = img.nx;
const int n_pos = n_frames / 2;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
ggml_tensor * inp = build_inp_raw(1);
// conv1d block
{
// convolution + gelu
ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
cur = ggml_add(ctx0, cur, model.conv1d_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
cur = ggml_add(ctx0, cur, model.conv1d_2_b);
cur = ggml_gelu_erf(ctx0, cur);
// transpose
inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
cb(inp, "after_conv1d", -1);
}
// sanity check (only check one layer, but it should be the same for all)
GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
GGML_ASSERT(model.layers[0].q_b);
GGML_ASSERT(model.layers[0].v_b);
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
GGML_ASSERT(model.post_ln_w && model.post_ln_b);
ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
model.position_embeddings->ne[0], n_pos,
model.position_embeddings->nb[1], 0
);
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
pos_embd_selected,
nullptr);
cb(cur, "after_transformer", -1);
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
{
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
}
cb(cur, "after_stacked", -1);
// UltravoxProjector
{
// pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// swiglu
{
int64_t split_point = cur->ne[0] / 2;
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
x1 = ggml_silu(ctx0, x1);
cur = ggml_mul(ctx0, x0, x1);
}
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
private:
//
// utility functions
@@ -1562,8 +1676,8 @@ private:
return inp;
}
ggml_tensor * build_inp_raw() {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
ggml_tensor * build_inp_raw(int channels = 3) {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
return inp_raw;
@@ -1641,6 +1755,11 @@ private:
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
} break;
case FFN_GELU_ERF:
{
cur = ggml_gelu_erf(ctx0, cur);
cb(cur, "ggml_gelu_erf", il);
} break;
case FFN_GELU_QUICK:
{
cur = ggml_gelu_quick(ctx0, cur);
@@ -1832,6 +1951,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_llama4();
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
res = graph.build_whisper_enc();
} break;
default:
{
res = graph.build_llava();
@@ -1915,18 +2038,30 @@ struct clip_model_loader {
// other hparams
{
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
get_u32(KEY_N_EMBD, hparams.n_embd);
get_u32(KEY_N_HEAD, hparams.n_head);
get_u32(KEY_N_FF, hparams.n_ff);
get_u32(KEY_N_BLOCK, hparams.n_layer);
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
const char * prefix = hparams.has_vision ? "vision" : "audio";
get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer);
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
if (hparams.has_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
} else if (hparams.has_audio) {
get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
} else {
throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
}
// default warmup value
hparams.warmup_image_size = hparams.image_size;
@@ -1964,7 +2099,7 @@ struct clip_model_loader {
}
}
{
if (hparams.has_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
@@ -2050,30 +2185,43 @@ struct clip_model_loader {
isize, isize*3, // 336, 1008
};
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
}
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf"; // temporary solution for logging
} break;
default:
break;
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("\n");
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("\n");
if (hparams.has_vision) {
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
} else if (hparams.has_audio) {
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
}
LOG_INF("\n");
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
}
}
@@ -2082,6 +2230,9 @@ struct clip_model_loader {
std::map<std::string, size_t> tensor_offset;
std::vector<ggml_tensor *> tensors_to_load;
// TODO @ngxson : support both audio and video in the future
const char * prefix = hparams.has_audio ? "a" : "v";
// get offsets
for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
@@ -2119,47 +2270,47 @@ struct clip_model_loader {
vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
// layers
vision_model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
// ffn
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
@@ -2301,6 +2452,17 @@ struct clip_model_loader {
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
} break;
case PROJECTOR_TYPE_INTERNVL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -2358,13 +2520,19 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
const auto & hparams = ctx_clip.vision_model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
img->nx = ctx_clip.vision_model.hparams.warmup_image_size;
img->ny = ctx_clip.vision_model.hparams.warmup_image_size;
if (hparams.has_vision) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
} else {
img->nx = 1024; // TODO @ngxson : use a better default
img->ny = hparams.n_mel_bins;
}
img->buf.resize(img->nx * img->ny * 3);
batch.entries.push_back(std::move(img));
@@ -3278,6 +3446,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
n_patches /= (scale_factor * scale_factor);
} else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
n_patches = n_len / proj_stack_factor / 2;
}
return n_patches;
@@ -3435,7 +3607,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
};
// set input pixel values
{
if (!imgs.is_audio) {
size_t nelem = 0;
for (const auto & img : imgs.entries) {
nelem += img->nx * img->ny * 3;
@@ -3472,6 +3644,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
set_input_f32("inp_raw", inp_raw);
} else {
// audio input
GGML_ASSERT(imgs.entries.size() == 1);
const auto & mel_inp = imgs.entries[0];
const int n_step = mel_inp->nx;
const int n_mel = mel_inp->ny;
std::vector<float> inp_raw(n_step * n_mel);
std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
set_input_f32("inp_raw", inp_raw);
}
// set input per projector
@@ -3668,6 +3850,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_ULTRAVOX:
{
// do nothing
} break;
@@ -3766,6 +3949,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->vision_model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
@@ -3798,6 +3983,14 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_vision;
}
bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_audio;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);
@@ -3818,3 +4011,14 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
return ctx->proj_type;
}
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
clip_image_f32 * audio = new clip_image_f32;
audio->nx = n_frames;
audio->ny = n_mel;
audio->buf.resize(n_frames * n_mel);
std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
batch->entries.push_back(clip_image_f32_ptr(audio));
batch->is_audio = true;
}
+6
View File
@@ -93,3 +93,9 @@ bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
// use by audio input
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
+93468
View File
File diff suppressed because it is too large Load Diff
+855
View File
@@ -0,0 +1,855 @@
// fix problem with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif
#include "mtmd-audio.h"
//#define MTMD_AUDIO_DEBUG
#define MINIAUDIO_IMPLEMENTATION
#ifndef MTMD_AUDIO_DEBUG
# define MA_NO_ENCODING
#endif
#define MA_NO_DEVICE_IO
#define MA_NO_RESOURCE_MANAGER
#define MA_NO_NODE_GRAPH
#define MA_NO_ENGINE
#define MA_NO_GENERATION
#define MA_API static
#include "miniaudio.h"
#define _USE_MATH_DEFINES // for M_PI
#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>
#include <fstream>
#include <algorithm>
// most of the code here is copied from whisper.cpp
// align x to upper multiple of n
#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
namespace whisper_preprocessor {
#define SIN_COS_N_COUNT WHISPER_N_FFT
namespace {
struct whisper_global_cache {
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
float sin_vals[SIN_COS_N_COUNT];
float cos_vals[SIN_COS_N_COUNT];
// Hann window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
float hann_window[WHISPER_N_FFT];
whisper_global_cache() {
fill_sin_cos_table();
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
}
void fill_sin_cos_table() {
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
void fill_hann_window(int length, bool periodic, float * output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
} global_cache;
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const float* in, int N, float* out) {
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
}
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(float* in, int N, float* out) {
if (N == 1) {
out[0] = in[0];
out[1] = 0;
return;
}
const int half_N = N / 2;
if (N - half_N*2 == 1) {
dft(in, N, out);
return;
}
float* even = in + N;
for (int i = 0; i < half_N; ++i) {
even[i]= in[2*i];
}
float* even_fft = out + 2 * N;
fft(even, half_N, even_fft);
float* odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i] = in[2*i + 1];
}
float* odd_fft = even_fft + N;
fft(odd, half_N, odd_fft);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < half_N; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = global_cache.cos_vals[idx]; // cos(t)
float im = -global_cache.sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
}
}
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size * 2, 0.0);
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
int n_fft = filters.n_fft;
int i = ith;
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;
// apply Hann window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
// fill the rest with zeros
if (n_samples - offset < frame_size) {
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
}
// FFT
fft(fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
for (int j = 0; j < n_fft; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
// unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
sum = log10(std::max(sum, 1e-10));
mel.data[j * mel.n_len + i] = sum;
}
}
// Otherwise fft_out are all zero
double sum = log10(1e-10);
for (; i < mel.n_len; i += n_threads) {
for (int j = 0; j < mel.n_mel; j++) {
mel.data[j * mel.n_len + i] = sum;
}
}
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
//const int64_t t_start_us = ggml_time_us();
// Hann window
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
const float * hann = global_cache.hann_window;
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
// Calculate number of frames + remove the last frame
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
// Calculate semi-padded sample length to ensure compatibility
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
mel.data.resize(mel.n_mel * mel.n_len);
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
}
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}
return true;
}
bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output) {
if (n_samples == 0) {
// empty audio
return false;
}
whisper_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
COMMON_SAMPLE_RATE,
WHISPER_N_FFT,
WHISPER_HOP_LENGTH,
filters.n_mel,
4, // n_threads
filters,
false, // debug
out_full);
if (!ok) {
return false;
}
// because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
// we always expect the mel to have 3000 silent frames at the end
// printf("n_len %d\n", out_full.n_len);
const size_t frames_per_chunk = 3000;
GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
if ((size_t)n_len < frames_per_chunk) {
break; // last uncomplete chunk will always be a padded chunk, safe to ignore
}
whisper_mel out_chunk;
out_chunk.n_len = n_len;
out_chunk.n_mel = out_full.n_mel;
out_chunk.n_len_org = out_full.n_mel; // unused
out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
for (int i = 0; i < out_full.n_mel; i++) {
auto src = out_full.data.begin() + i*out_full.n_len + off;
out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
}
output.push_back(std::move(out_chunk));
}
return true;
}
} // namespace whisper_preprocessor
namespace audio_helpers {
bool is_audio_file(const char * buf, size_t len) {
if (len < 12) {
return false;
}
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
bool is_mp3 = len >= 3 && (
memcmp(buf, "ID3", 3) == 0 ||
// Check for MPEG sync word (simplified check)
((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
);
bool is_flac = memcmp(buf, "fLaC", 4) == 0;
return is_wav || is_mp3 || is_flac;
}
// returns true if the buffer is a valid audio file
bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
ma_result result;
const int channels = 1;
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
ma_decoder decoder;
result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
if (result != MA_SUCCESS) {
return false;
}
ma_uint64 frame_count;
ma_uint64 frames_read;
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
if (result != MA_SUCCESS) {
ma_decoder_uninit(&decoder);
return false;
}
pcmf32_mono.resize(frame_count);
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
if (result != MA_SUCCESS) {
ma_decoder_uninit(&decoder);
return false;
}
#ifdef MTMD_AUDIO_DEBUG
// save audio to wav file
ma_encoder_config config = ma_encoder_config_init(ma_encoding_format_wav, ma_format_f32, 1, target_sampler_rate);
ma_encoder encoder;
ma_encoder_init_file("output.wav", &config, &encoder);
ma_encoder_write_pcm_frames(&encoder, pcmf32_mono.data(), pcmf32_mono.size(), &frames_read);
ma_encoder_uninit(&encoder);
#endif
ma_decoder_uninit(&decoder);
return true;
}
} // namespace wav_utils
// precalculated mel filter banks
// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
//
// generated from python code:
//
// from numpy import load
// data = load('mel_filters.npz')
// lst = data.files
// for item in lst:
// print(item)
// print(data[item].shape)
// n_mel = data[item].shape[0]
// n_fft = data[item].shape[1]
// for i, row in enumerate(data[item]):
// for j, val in enumerate(row):
// val = val * 1000.0
// if val != 0:
// print(f"data[{i*n_fft + j}] = {val:.6f};")
namespace whisper_precalc_filters {
whisper_preprocessor::whisper_filters get_128_bins() {
whisper_preprocessor::whisper_filters filters;
filters.n_mel = 128;
filters.n_fft = 201;
std::vector data(filters.n_mel * filters.n_fft, 0.0f);
data[1] = 12.37398665;
data[202] = 30.39256483;
data[404] = 24.74797331;
data[605] = 18.01857911;
data[807] = 37.12195903;
data[1008] = 5.64459199;
data[1009] = 6.72939420;
data[1210] = 36.03715822;
data[1412] = 19.10337992;
data[1613] = 23.66316877;
data[1815] = 31.47736564;
data[2016] = 11.28918398;
data[2017] = 1.08480197;
data[2218] = 41.68175161;
data[2420] = 13.45878839;
data[2621] = 29.30776216;
data[2823] = 25.83277412;
data[3024] = 16.93377644;
data[3226] = 38.20675984;
data[3427] = 4.55979025;
data[3428] = 7.81419594;
data[3629] = 34.95235741;
data[3831] = 20.18818259;
data[4032] = 22.57836796;
data[4234] = 32.56217018;
data[4435] = 10.20438317;
data[4436] = 2.16960395;
data[4637] = 40.59694707;
data[4839] = 14.54358920;
data[5040] = 28.22295949;
data[5242] = 26.91757679;
data[5443] = 15.84897563;
data[5645] = 39.29156065;
data[5846] = 3.47498828;
data[5847] = 8.89899861;
data[6048] = 33.86755288;
data[6250] = 21.27298526;
data[6451] = 21.49356715;
data[6653] = 33.64697099;
data[6854] = 9.11958050;
data[6855] = 3.25440569;
data[7056] = 39.51214626;
data[7258] = 15.62839188;
data[7459] = 27.13815868;
data[7661] = 28.00237760;
data[7862] = 14.76417296;
data[8064] = 40.37636518;
data[8265] = 2.38068704;
data[8266] = 10.20263787;
data[8467] = 31.61146119;
data[8669] = 24.54700135;
data[8870] = 15.32919332;
data[8871] = 1.66583748;
data[9072] = 36.72905266;
data[9274] = 20.09709924;
data[9475] = 16.93102531;
data[9476] = 2.90265540;
data[9677] = 32.84499049;
data[9879] = 23.52004871;
data[10080] = 11.03894413;
data[10081] = 10.72582975;
data[10282] = 22.71829173;
data[10484] = 32.27872774;
data[10685] = 0.11626833;
data[10686] = 22.85348251;
data[10887] = 8.56344029;
data[10888] = 14.97978810;
data[11089] = 15.51398356;
data[11090] = 8.51490628;
data[11291] = 21.10680379;
data[11292] = 3.32652032;
data[11493] = 25.47064796;
data[11695] = 27.35907957;
data[11896] = 0.65853616;
data[11897] = 23.83812517;
data[12098] = 3.44359246;
data[12099] = 21.22455277;
data[12300] = 5.35842171;
data[12301] = 19.42555793;
data[12502] = 6.49324711;
data[12503] = 18.35542172;
data[12704] = 6.93138083;
data[12705] = 17.93504693;
data[12906] = 6.74968259;
data[12907] = 18.09151843;
data[13108] = 6.01899112;
data[13109] = 18.75767298;
data[13310] = 4.80452832;
data[13311] = 19.87172849;
data[13512] = 3.16627859;
data[13513] = 21.37690969;
data[13514] = 1.25317345;
data[13714] = 1.15934468;
data[13715] = 20.80361731;
data[13716] = 4.04486805;
data[13917] = 17.55363122;
data[13918] = 7.08320038;
data[14119] = 14.07538634;
data[14120] = 10.32655034;
data[14321] = 10.40921453;
data[14322] = 13.73696327;
data[14523] = 6.59187697;
data[14524] = 17.27988198;
data[14525] = 1.46804214;
data[14725] = 2.65681883;
data[14726] = 18.09193194;
data[14727] = 5.85655728;
data[14928] = 13.34277913;
data[14929] = 10.28267574;
data[15130] = 8.56800377;
data[15131] = 14.72230814;
data[15132] = 1.04039861;
data[15332] = 3.79085587;
data[15333] = 17.14678481;
data[15334] = 6.11609267;
data[15535] = 11.75929047;
data[15536] = 11.13393717;
data[15737] = 6.43857848;
data[15738] = 16.07806236;
data[15739] = 4.23917221;
data[15939] = 1.19989377;
data[15940] = 12.75671553;
data[15941] = 9.65298992;
data[16142] = 7.06935255;
data[16143] = 14.94054683;
data[16144] = 4.19024844;
data[16344] = 1.51483389;
data[16345] = 12.00899947;
data[16346] = 9.84823331;
data[16547] = 6.10224018;
data[16548] = 15.33857174;
data[16549] = 5.57676842;
data[16749] = 0.36827257;
data[16750] = 9.89749376;
data[16751] = 11.35340426;
data[16752] = 2.05122307;
data[16952] = 3.89297144;
data[16953] = 12.97352277;
data[16954] = 8.06631614;
data[17155] = 6.74493238;
data[17156] = 13.85874674;
data[17157] = 5.41190524;
data[17357] = 0.74220158;
data[17358] = 8.98779090;
data[17359] = 11.37871388;
data[17360] = 3.32958088;
data[17560] = 2.82313535;
data[17561] = 10.68049297;
data[17562] = 9.43340641;
data[17563] = 1.76325557;
data[17763] = 4.39018616;
data[17764] = 11.87758986;
data[17765] = 7.97005836;
data[17766] = 0.66104700;
data[17966] = 5.49466675;
data[17967] = 12.62953598;
data[17968] = 6.93987962;
data[18169] = 6.18401915;
data[18170] = 12.93473132;
data[18171] = 6.29778765;
data[18371] = 0.02325210;
data[18372] = 6.50206627;
data[18373] = 12.32661773;
data[18374] = 6.00216538;
data[18574] = 0.31548753;
data[18575] = 6.48925547;
data[18576] = 12.04130240;
data[18577] = 6.01462880;
data[18777] = 0.29979556;
data[18778] = 6.18288014;
data[18779] = 12.04272825;
data[18780] = 6.29981188;
data[18781] = 0.55689598;
data[18980] = 0.01120471;
data[18981] = 5.61729167;
data[18982] = 11.22337859;
data[18983] = 6.82516303;
data[18984] = 1.35264499;
data[19184] = 4.82410006;
data[19185] = 10.16623247;
data[19186] = 7.56075513;
data[19187] = 2.34590308;
data[19387] = 3.83235747;
data[19388] = 8.92296247;
data[19389] = 8.47910438;
data[19390] = 3.50978645;
data[19590] = 2.66873185;
data[19591] = 7.51965167;
data[19592] = 9.55500547;
data[19593] = 4.81966138;
data[19594] = 0.08431751;
data[19793] = 1.35767367;
data[19794] = 5.98019501;
data[19795] = 10.60271543;
data[19796] = 6.25298498;
data[19797] = 1.74059917;
data[19997] = 4.32644226;
data[19998] = 8.73131864;
data[19999] = 7.78916525;
data[20000] = 3.48923868;
data[20200] = 2.57835095;
data[20201] = 6.77582854;
data[20202] = 9.40941647;
data[20203] = 5.31194592;
data[20204] = 1.21447595;
data[20403] = 0.75411191;
data[20404] = 4.75395704;
data[20405] = 8.75380263;
data[20406] = 7.19209015;
data[20407] = 3.28754401;
data[20607] = 2.68179690;
data[20608] = 6.49331464;
data[20609] = 9.11457930;
data[20610] = 5.39387390;
data[20611] = 1.67316827;
data[20810] = 0.57394296;
data[20811] = 4.20600036;
data[20812] = 7.83805829;
data[20813] = 7.52023002;
data[20814] = 3.97470826;
data[20815] = 0.42918732;
data[21014] = 1.90464477;
data[21015] = 5.36569161;
data[21016] = 8.82673822;
data[21017] = 6.27609482;
data[21018] = 2.89750961;
data[21218] = 2.89885257;
data[21219] = 6.19694078;
data[21220] = 8.56699049;
data[21221] = 5.34748193;
data[21222] = 2.12797290;
data[21421] = 0.44750227;
data[21422] = 3.59030394;
data[21423] = 6.73310598;
data[21424] = 7.77023612;
data[21425] = 4.70231380;
data[21426] = 1.63439126;
data[21625] = 1.01536023;
data[21626] = 4.01018746;
data[21627] = 7.00501446;
data[21628] = 7.23442994;
data[21629] = 4.31095669;
data[21630] = 1.38748321;
data[21829] = 1.33348850;
data[21830] = 4.18730825;
data[21831] = 7.04112789;
data[21832] = 6.93188375;
data[21833] = 4.14605811;
data[21834] = 1.36023236;
data[22033] = 1.42879714;
data[22034] = 4.14824858;
data[22035] = 6.86769979;
data[22036] = 6.83705276;
data[22037] = 4.18239459;
data[22038] = 1.52773573;
data[22237] = 1.32610439;
data[22238] = 3.91751388;
data[22239] = 6.50892360;
data[22240] = 6.92639686;
data[22241] = 4.39672917;
data[22242] = 1.86706171;
data[22441] = 1.04827771;
data[22442] = 3.51767405;
data[22443] = 5.98707050;
data[22444] = 7.17824046;
data[22445] = 4.76767914;
data[22446] = 2.35711760;
data[22645] = 0.61636406;
data[22646] = 2.96949223;
data[22647] = 5.32262027;
data[22648] = 7.57265091;
data[22649] = 5.27558755;
data[22650] = 2.97852419;
data[22651] = 0.68146095;
data[22849] = 0.04971400;
data[22850] = 2.29204819;
data[22851] = 4.53438237;
data[22852] = 6.77671656;
data[22853] = 5.90240723;
data[22854] = 3.71349836;
data[22855] = 1.52458926;
data[23054] = 1.50285335;
data[23055] = 3.63961048;
data[23056] = 5.77636715;
data[23057] = 6.63159089;
data[23058] = 4.54574358;
data[23059] = 2.45989650;
data[23060] = 0.37404924;
data[23258] = 0.61795861;
data[23259] = 2.65410915;
data[23260] = 4.69025923;
data[23261] = 6.72641024;
data[23262] = 5.46034705;
data[23263] = 3.47270933;
data[23264] = 1.48507138;
data[23463] = 1.59233576;
data[23464] = 3.53261665;
data[23465] = 5.47289755;
data[23466] = 6.44368259;
data[23467] = 4.54962999;
data[23468] = 2.65557761;
data[23469] = 0.76152512;
data[23667] = 0.46749352;
data[23668] = 2.31641904;
data[23669] = 4.16534441;
data[23670] = 6.01426978;
data[23671] = 5.67844696;
data[23672] = 3.87357362;
data[23673] = 2.06870004;
data[23674] = 0.26382666;
data[23872] = 1.05349103;
data[23873] = 2.81536230;
data[23874] = 4.57723346;
data[23875] = 6.33910485;
data[23876] = 5.12815686;
data[23877] = 3.40826320;
data[23878] = 1.68837002;
data[24077] = 1.43350090;
data[24078] = 3.11241671;
data[24079] = 4.79133241;
data[24080] = 6.40943693;
data[24081] = 4.77052201;
data[24082] = 3.13160778;
data[24083] = 1.49269309;
data[24281] = 0.02932359;
data[24282] = 1.62918994;
data[24283] = 3.22905602;
data[24284] = 4.82892245;
data[24285] = 6.14671456;
data[24286] = 4.58496623;
data[24287] = 3.02321767;
data[24288] = 1.46146910;
data[24486] = 0.13601698;
data[24487] = 1.66055572;
data[24488] = 3.18509457;
data[24489] = 4.70963307;
data[24490] = 6.04072399;
data[24491] = 4.55250870;
data[24492] = 3.06429295;
data[24493] = 1.57607743;
data[24494] = 0.08786193;
data[24691] = 0.09328097;
data[24692] = 1.54603878;
data[24693] = 2.99879676;
data[24694] = 4.45155473;
data[24695] = 5.90431225;
data[24696] = 4.65566106;
data[24697] = 3.23751615;
data[24698] = 1.81937125;
data[24699] = 0.40122634;
data[24897] = 1.30262633;
data[24898] = 2.68698297;
data[24899] = 4.07133950;
data[24900] = 5.45569602;
data[24901] = 4.87832492;
data[24902] = 3.52695142;
data[24903] = 2.17557792;
data[24904] = 0.82420459;
data[25102] = 0.94595028;
data[25103] = 2.26512621;
data[25104] = 3.58430226;
data[25105] = 4.90347855;
data[25106] = 5.20569785;
data[25107] = 3.91795207;
data[25108] = 2.63020652;
data[25109] = 1.34246063;
data[25110] = 0.05471494;
data[25307] = 0.49037894;
data[25308] = 1.74744334;
data[25309] = 3.00450763;
data[25310] = 4.26157191;
data[25311] = 5.51863620;
data[25312] = 4.39707236;
data[25313] = 3.16995848;
data[25314] = 1.94284460;
data[25315] = 0.71573065;
data[25513] = 1.14698056;
data[25514] = 2.34485767;
data[25515] = 3.54273478;
data[25516] = 4.74061165;
data[25517] = 4.95198462;
data[25518] = 3.78264743;
data[25519] = 2.61331047;
data[25520] = 1.44397374;
data[25521] = 0.27463681;
data[25718] = 0.47569509;
data[25719] = 1.61717169;
data[25720] = 2.75864848;
data[25721] = 3.90012516;
data[25722] = 5.04160160;
data[25723] = 4.45712078;
data[25724] = 3.34284059;
data[25725] = 2.22856039;
data[25726] = 1.11428020;
for (auto & val : data) {
val /= 1000.0f;
}
filters.data = std::move(data);
return filters;
}
} // namespace whisper_precalc_filters
+62
View File
@@ -0,0 +1,62 @@
#pragma once
#include "ggml.h"
#include <cstdint>
#include <vector>
#include <string>
#define WHISPER_ASSERT GGML_ASSERT
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define COMMON_SAMPLE_RATE 16000
namespace whisper_preprocessor {
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
extern bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output);
} // namespace whisper_preprocessor
// TODO @ngxson : move this helper to mtmd-helpers.cpp
namespace audio_helpers {
extern bool is_audio_file(const char * buf, size_t len);
extern bool decode_audio_from_buf(
const unsigned char * buf_in,
size_t len,
int target_sampler_rate,
std::vector<float> & pcmf32_mono);
} // namespace audio_helpers
namespace whisper_precalc_filters {
extern whisper_preprocessor::whisper_filters get_128_bins();
} // namespace whisper_precalc_filters
+21 -14
View File
@@ -37,10 +37,10 @@ static volatile bool g_is_interrupted = false;
static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Experimental CLI for multimodal\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> --audio <audio> -p <prompt>\n\n"
" -m and --mmproj are required\n"
" -hf user/repo can replace both -m and --mmproj in most cases\n"
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" --image, --audio and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" to disable using GPU for mmproj model, add --no-mmproj-offload\n",
argv[0]
);
@@ -142,7 +142,7 @@ struct mtmd_cli_context {
);
}
bool load_image(const std::string & fname) {
bool load_media(const std::string & fname) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
if (!bmp.ptr) {
return false;
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
common_params params;
params.sampling.temp = 0.2; // lower temp by default for better quality
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
return 1;
}
@@ -283,14 +283,14 @@ int main(int argc, char ** argv) {
if (is_single_turn) {
g_is_generating = true;
if (params.prompt.find("<__image__>") == std::string::npos) {
params.prompt += " <__image__>";
if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
params.prompt += mtmd_default_marker();
}
common_chat_msg msg;
msg.role = "user";
msg.content = params.prompt;
for (const auto & image : params.image) {
if (!ctx.load_image(image)) {
if (!ctx.load_media(image)) {
return 1; // error is already printed by libmtmd
}
}
@@ -303,7 +303,12 @@ int main(int argc, char ** argv) {
} else {
LOG("\n Running in chat mode, available commands:");
LOG("\n /image <path> load an image");
if (mtmd_support_vision(ctx.ctx_vision.get())) {
LOG("\n /image <path> load an image");
}
if (mtmd_support_audio(ctx.ctx_vision.get())) {
LOG("\n /audio <path> load an audio");
}
LOG("\n /clear clear the chat history");
LOG("\n /quit or /exit exit the program");
LOG("\n");
@@ -333,15 +338,17 @@ int main(int argc, char ** argv) {
continue;
}
g_is_generating = true;
if (line == "/image" || line.find("/image ") == 0) {
bool is_image = line == "/image" || line.find("/image ") == 0;
bool is_audio = line == "/audio" || line.find("/audio ") == 0;
if (is_image || is_audio) {
if (line.size() < 8) {
LOG_ERR("ERR: Missing image filename\n");
LOG_ERR("ERR: Missing media filename\n");
continue;
}
std::string image = line.substr(7);
if (ctx.load_image(image)) {
LOG("Image %s loaded\n", image.c_str());
content += "<__image__>";
std::string media_path = line.substr(7);
if (ctx.load_media(media_path)) {
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
content += mtmd_default_marker();
}
// else, error is already printed by libmtmd
continue;
+29 -44
View File
@@ -12,17 +12,7 @@ size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
size_t n_tokens = 0;
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
auto chunk = mtmd_input_chunks_get(chunks, i);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens_text;
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
n_tokens += n_tokens_text;
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
} else {
GGML_ASSERT(false && "chunk type not supported");
}
n_tokens += mtmd_input_chunk_get_n_tokens(chunk);
}
return n_tokens;
}
@@ -31,17 +21,7 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
llama_pos n_pos = 0;
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
auto chunk = mtmd_input_chunks_get(chunks, i);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens_text;
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
n_pos += n_tokens_text;
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
} else {
GGML_ASSERT(false && "chunk type not supported");
}
n_pos += mtmd_input_chunk_get_n_pos(chunk);
}
return n_pos;
}
@@ -149,13 +129,10 @@ int32_t mtmd_helper_decode_image_chunk(
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
return -1;
}
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
if (!image_tokens) {
LOG_ERR("failed to decode image chunk: image tokens are null\n");
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
LOG_ERR("failed to decode chunk: input chunk not of image/audio type\n");
return -1;
}
@@ -163,15 +140,23 @@ int32_t mtmd_helper_decode_image_chunk(
int n_mmproj_embd = llama_model_n_embd(model);
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
int32_t i_batch = 0;
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
if (mtmd_decode_use_mrope(ctx)) {
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n");
return -1;
}
if (!image_tokens) {
LOG_ERR("failed to decode chunk: image tokens are null\n");
return -1;
}
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
} else {
batch_embd.set_position_normal(n_past, seq_id);
@@ -187,22 +172,22 @@ int32_t mtmd_helper_decode_image_chunk(
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
LOG_INF("decoding %s batch %d/%d, n_tokens_batch = %d\n", name, i_batch+1, n_img_batches, n_tokens_batch);
int64_t t1 = ggml_time_ms();
int32_t ret = llama_decode(lctx, batch_embd_view);
if (ret != 0) {
LOG_ERR("failed to decode image\n");
LOG_ERR("failed to decode %s\n", name);
llama_set_causal_attn(lctx, true); // restore causal attn
return ret;
}
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
i_batch++;
}
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
n_past += mtmd_input_chunk_get_n_pos(chunk);
*new_n_past = n_past;
if (mtmd_decode_use_non_causal(ctx)) {
@@ -253,25 +238,25 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
*new_n_past += text_batch.n_tokens;
}
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
int64_t t0 = ggml_time_ms();
LOG_INF("encoding image or slice...\n");
LOG_INF("encoding %s slice...\n", name);
ret = mtmd_encode(ctx, image_tokens);
ret = mtmd_encode_chunk(ctx, chunk);
if (ret != 0) {
LOG_ERR("failed to encode image\n");
LOG_ERR("failed to encode %s slice\n", name);
llama_batch_free(text_batch);
return ret;
}
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
float * embd = mtmd_get_output_embd(ctx);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
if (ret != 0) {
LOG_ERR("failed to decode image\n");
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);
return ret;
}
+294 -73
View File
@@ -1,6 +1,7 @@
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"
#include "mtmd-audio.h"
#include "llama.h"
@@ -19,17 +20,49 @@ struct mtmd_bitmap {
uint32_t ny;
std::vector<unsigned char> data;
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
bool is_audio = false; // true if the bitmap is audio
};
struct mtmd_image_tokens_deleter {
void operator()(mtmd_image_tokens * val); // forward declaration
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_image_tokens clone() {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
batch_f32.clone(),
id
};
}
};
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens>;
struct mtmd_audio_tokens {
uint32_t n_tokens; // number of tokens
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_audio_tokens clone() {
return mtmd_audio_tokens{
n_tokens,
batch_f32.clone(),
id
};
}
};
using mtmd_audio_tokens_ptr = std::unique_ptr<mtmd_audio_tokens>;
struct mtmd_input_chunk {
mtmd_input_chunk_type type;
std::vector<llama_token> tokens_text;
mtmd_image_tokens_ptr tokens_image;
mtmd_audio_tokens_ptr tokens_audio;
};
struct mtmd_input_chunks {
@@ -46,6 +79,10 @@ enum mtmd_slice_tmpl {
// TODO @ngxson : add support for idefics (SmolVLM)
};
const char * mtmd_default_marker() {
return "<__media__>";
}
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
@@ -53,6 +90,7 @@ mtmd_context_params mtmd_context_params_default() {
params.n_threads = 4;
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
return params;
}
@@ -63,7 +101,9 @@ struct mtmd_context {
bool print_timings;
int n_threads;
std::string image_marker;
std::string media_marker;
bool has_vision;
bool has_audio;
// for llava-uhd style models, we need special tokens in-between slices
// minicpmv calls them "slices", llama 4 calls them "tiles"
@@ -81,6 +121,9 @@ struct mtmd_context {
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
// for whisper, we pre-calculate the mel filter bank
whisper_preprocessor::whisper_filters w_filters;
// TODO @ngxson : add timings
mtmd_context(const char * mmproj_fname,
@@ -89,8 +132,12 @@ struct mtmd_context {
text_model (text_model),
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
image_marker (ctx_params.image_marker)
media_marker (ctx_params.media_marker)
{
if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
}
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -99,7 +146,9 @@ struct mtmd_context {
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
}
use_mrope = clip_is_qwen2vl(ctx_clip);
has_vision = clip_has_vision_encoder(ctx_clip);
has_audio = clip_has_audio_encoder(ctx_clip);
use_mrope = clip_is_qwen2vl(ctx_clip);
projector_type proj = clip_get_projector_type(ctx_clip);
int minicpmv_version = clip_is_minicpmv(ctx_clip);
@@ -146,6 +195,21 @@ struct mtmd_context {
tok_row_end_trail = true; // add trailing end-of-row token
ov_img_first = false; // overview image is last
}
if (proj == PROJECTOR_TYPE_ULTRAVOX) {
// TODO @ngxson : check if model n_mel is 128 or 80
w_filters = whisper_precalc_filters::get_128_bins();
}
// warning messages
if (proj == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
if (has_audio) {
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
}
}
~mtmd_context() {
@@ -179,29 +243,6 @@ private:
}
};
struct mtmd_image_tokens_data {
clip_image_f32_batch batch_f32; // preprocessed image patches
};
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_image_tokens clone() {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
batch_f32.clone(),
id
};
}
};
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
const struct llama_model * text_model,
const struct mtmd_context_params ctx_params) {
@@ -247,59 +288,63 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
auto vocab = llama_model_get_vocab(ctx->text_model);
std::string prompt_modified(text->text);
std::string marker_modified(ctx->image_marker);
std::string marker_modified(ctx->media_marker);
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
// for compatibility, we convert image marker to media marker
string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
// a bit hacky here, but works for now
// for some models, we need to add prefix and suffix to the image embeddings
if (clip_is_gemma3(ctx->ctx_clip)) {
// gemma 3
// <start_of_image> ... (image embeddings) ... <end_of_image>
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
marker_modified = ctx->image_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = ctx->media_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
// (more details in mtmd_context constructor)
marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
// <img> ... (image embeddings) ... </img>
marker_modified = "<img>" + ctx->image_marker + "</img>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<img>" + ctx->media_marker + "</img>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
}
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
// for glm-edge, BOI and EOI token's embeddings are not present in the text model
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
output->entries.clear();
output->entries.reserve(parts.size());
size_t i_img = 0;
size_t i_bm = 0;
// utility for adding raw tokens
auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
};
@@ -317,8 +362,9 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
chunks.emplace_back(std::move(chunk));
}
@@ -336,24 +382,36 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
if (&parts.back() != &part) {
// add image token to middle of 2 parts
// only add image/audio tokens to middle of 2 parts
// therefore, we skip handling image/audio if this is the last part
if (&parts.back() == &part) {
continue;
}
if (i_img >= n_bitmaps) {
if (!bitmaps[i_bm]->is_audio) {
// handle image
if (i_bm >= n_bitmaps) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return 1;
}
if (!ctx->has_vision) {
LOG_ERR("%s: error: model does not support vision input\n", __func__);
return 2;
}
// convert mtmd_bitmap to clip_image_u8
clip_image_u8_ptr img_u8(clip_image_u8_init());
img_u8->nx = bitmaps[i_img]->nx;
img_u8->ny = bitmaps[i_img]->ny;
img_u8->buf.resize(bitmaps[i_img]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
img_u8->nx = bitmaps[i_bm]->nx;
img_u8->ny = bitmaps[i_bm]->ny;
img_u8->buf.resize(bitmaps[i_bm]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);
// preprocess image
clip_image_f32_batch batch_f32;
@@ -370,7 +428,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
) {
// split batch into chunks of single images
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
GGML_ASSERT(chunks.size() > 0);
auto ov_chunk = std::move(chunks.front());
@@ -446,7 +504,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
image_tokens->ny = 1;
}
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[i_img]->id; // optional
image_tokens->id = bitmaps[i_bm]->id; // optional
LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
@@ -454,23 +512,101 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
}
i_img++; // move to next image
i_bm++; // move to next image
continue;
} else {
// handle audio
if (i_bm >= n_bitmaps) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return 1;
}
if (!ctx->has_audio) {
LOG_ERR("%s: error: model does not support audio input\n", __func__);
return 2;
}
if (bitmaps[i_bm]->data.size() == 0) {
LOG_ERR("%s: error: empty audio data\n", __func__);
return 2;
}
// preprocess audio
GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
const float * samples = (const float *)bitmaps[i_bm]->data.data();
size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
if (!ok) {
LOG_ERR("Unable to preprocess audio\n");
return 2;
}
// consider each mel_spec as a separate audio chunk
// TODO: maybe support batching, but this may come with memory cost
for (auto & mel_spec : mel_spec_chunks) {
clip_image_f32_ptr mel_f32(clip_image_f32_init());
mel_f32->nx = mel_spec.n_len;
mel_f32->ny = mel_spec.n_mel;
mel_f32->buf = std::move(mel_spec.data);
size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
clip_image_f32_batch batch_f32;
batch_f32.is_audio = true;
batch_f32.entries.push_back(std::move(mel_f32));
mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
audio_tokens->n_tokens = n_tokens;
audio_tokens->batch_f32 = std::move(batch_f32);
audio_tokens->id = bitmaps[i_bm]->id; // optional
LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_AUDIO,
{}, // text tokens
nullptr, // image tokens
std::move(audio_tokens),
};
output->entries.emplace_back(std::move(chunk));
}
i_bm++;
continue;
}
}
return 0;
}
static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
if (image_tokens) {
delete image_tokens;
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
return 0;
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_encode(ctx, chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
bool ok = clip_image_batch_encode(
ctx->ctx_clip,
ctx->n_threads,
&chunk->tokens_audio->batch_f32,
ctx->image_embd_v.data());
return ok ? 0 : 1;
}
LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
return 1;
}
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@@ -516,8 +652,12 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
return ctx->use_mrope;
}
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
mtmd_image_tokens_free(val);
bool mtmd_support_vision(mtmd_context * ctx) {
return ctx->has_vision;
}
bool mtmd_support_audio(mtmd_context * ctx) {
return ctx->has_audio;
}
// these 2 helpers below use internal clip_image_u8_ptr,
@@ -526,6 +666,15 @@ void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
// whichever library they want, and then use mtmd_bitmap_init() to create bitmap
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
if (audio_helpers::is_audio_file((const char *)buf, len)) {
std::vector<float> pcmf32;
if (!audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
LOG_ERR("Unable to read WAV audio file from buffer\n");
return nullptr;
}
return mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
}
clip_image_u8_ptr img_u8(clip_image_u8_init());
bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
if (!ok) {
@@ -538,15 +687,26 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t
}
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
clip_image_u8_ptr img_u8(clip_image_u8_init());
bool ok = clip_image_load_from_file(fname, img_u8.get());
if (!ok) {
LOG_ERR("Unable to load image %s\n", fname);
std::vector<unsigned char> buf;
FILE * f = fopen(fname, "rb");
if (!f) {
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
return nullptr;
}
uint32_t nx, ny;
unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
return mtmd_bitmap_init(nx, ny, data);
fseek(f, 0, SEEK_END);
long file_size = ftell(f);
fseek(f, 0, SEEK_SET);
buf.resize(file_size);
size_t n_read = fread(buf.data(), 1, file_size, f);
fclose(f);
if (n_read != (size_t)file_size) {
LOG_ERR("Failed to read entire file %s", fname);
return nullptr;
}
return mtmd_helper_bitmap_init_from_buf(buf.data(), buf.size());
}
//
@@ -567,6 +727,18 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
return bitmap;
}
mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
const float * data) {
mtmd_bitmap * bitmap = new mtmd_bitmap;
bitmap->nx = n_samples;
bitmap->ny = 1;
bitmap->is_audio = true;
size_t data_size = n_samples * sizeof(float);
bitmap->data.resize(data_size);
std::memcpy(bitmap->data.data(), data, data_size);
return bitmap;
}
uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) {
return bitmap->nx;
}
@@ -579,6 +751,14 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
return bitmap->data.data();
}
size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
return bitmap->data.size();
}
bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
return bitmap->is_audio;
}
const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
return bitmap->id.c_str();
}
@@ -642,17 +822,56 @@ const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chu
return nullptr;
}
size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
return chunk->tokens_text.size();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_image_tokens_get_n_tokens(chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->n_tokens;
} else {
GGML_ABORT("invalid chunk type");
}
}
llama_pos mtmd_input_chunk_get_n_pos(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
return chunk->tokens_text.size();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_image_tokens_get_n_pos(chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->n_tokens;
} else {
GGML_ABORT("invalid chunk type");
}
}
const char * mtmd_input_chunk_get_id(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return chunk->tokens_image->id.c_str();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->id.c_str();
}
return nullptr;
}
mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) {
mtmd_input_chunk * copy = new mtmd_input_chunk{
chunk->type,
chunk->tokens_text,
mtmd_image_tokens_ptr(),
nullptr,
nullptr,
};
if (chunk->tokens_image) {
// copy the image tokens
copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
*copy->tokens_image = chunk->tokens_image->clone();
}
if (chunk->tokens_audio) {
// copy the audio tokens
copy->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
*copy->tokens_audio = chunk->tokens_audio->clone();
}
return copy;
}
@@ -700,7 +919,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
mtmd_input_chunk chunk_text{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens_text),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
chunks->entries.emplace_back(std::move(chunk_text));
@@ -712,8 +932,9 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
image_tokens->id = "image_1";
mtmd_input_chunk chunk_image{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
chunks->entries.emplace_back(std::move(chunk_image));
+52 -21
View File
@@ -39,6 +39,7 @@
# define MTMD_API
#endif
// deprecated marker, use mtmd_default_marker() instead
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
#ifdef __cplusplus
@@ -48,6 +49,7 @@ extern "C" {
enum mtmd_input_chunk_type {
MTMD_INPUT_CHUNK_TYPE_TEXT,
MTMD_INPUT_CHUNK_TYPE_IMAGE,
MTMD_INPUT_CHUNK_TYPE_AUDIO,
};
// opaque types
@@ -79,9 +81,12 @@ struct mtmd_context_params {
bool print_timings;
int n_threads;
enum ggml_log_level verbosity;
const char * image_marker;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
};
MTMD_API const char * mtmd_default_marker(void);
MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
// initialize the mtmd context
@@ -98,18 +103,28 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
// whether the current model use M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
// whether the current model supports vision input
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
// whether the current model supports audio input
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
// mtmd_bitmap
//
// length of data must be nx * ny * 3
// the data is in RGBRGBRGB... format
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx,
uint32_t ny,
const unsigned char * data);
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
// if bitmap is image:
// length of data must be nx * ny * 3
// the data is in RGBRGBRGB... format
// if bitmap is audio:
// length of data must be n_samples * sizeof(float)
// the data is in float format (PCM F32)
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
// bitmap ID is optional, but useful for KV cache tracking
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
@@ -132,6 +147,11 @@ MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chu
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk);
MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);
// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
// you can move the chunk ownership to your own code by copying it
@@ -144,27 +164,28 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunk
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens);
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// tokenize an input text prompt and an image
// the prompt must have the input image marker (default: "<__image__>") in it
// the marker will be replaced with the image tokens
// tokenize an input text prompt and a list of bitmaps (images/audio)
// the prompt must have the input image marker (default: "<__media__>") in it
// the default marker is defined by mtmd_default_marker()
// the marker will be replaced with the image/audio chunk
// for example:
// "here is an image: <__image__>\ndescribe it in detail."
// "here is an image: <__media__>\ndescribe it in detail."
// this will gives 3 chunks:
// 1. "here is an image: <start_of_image>"
// 2. (image tokens)
// 2. (image/audio tokens)
// 3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// number of bitmaps must be equal to the number of markers in the prompt
// this function is thread-safe (shared ctx)
// return values:
// 0 on success
// 1 on number of images not matching the number of markers
// 1 on number of bitmaps not matching the number of markers
// 2 on image preprocessing error
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunks * output,
@@ -173,9 +194,14 @@ MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
size_t n_bitmaps);
// returns 0 on success
// TODO: deprecate
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
const mtmd_image_tokens * image_tokens);
// returns 0 on success
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
const mtmd_input_chunk * chunk);
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
@@ -189,12 +215,16 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
//
// helper function to construct a mtmd_bitmap from a file
// it calls mtmd_helper_bitmap_init_from_buf() internally
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
// helper function to construct a mtmd_bitmap from a buffer containing a file
// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
// supported formats:
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
// audio: formats supported by miniaudio: wav, mp3, flac
// note: audio files will be auto-detected based on magic bytes
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);
@@ -293,6 +323,7 @@ struct bitmap {
uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
};
Binary file not shown.
+219 -171
View File
@@ -1,3 +1,4 @@
#include "chat.h"
#include "utils.hpp"
#include "arg.h"
@@ -114,11 +115,11 @@ struct slot_params {
struct common_params_speculative speculative;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
json to_json() const {
std::vector<std::string> samplers;
@@ -176,7 +177,10 @@ struct slot_params {
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@@ -352,11 +356,14 @@ struct server_task {
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
}
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
}
{
@@ -396,7 +403,14 @@ struct server_task {
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
params.sampling.grammar_triggers.push_back(std::move(ct.value));
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
} else {
throw std::runtime_error("Unknown grammar trigger type");
}
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
}
}
}
@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
SRV_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
if (!oaicompat_msg.empty()) {
msg = oaicompat_msg;
} else {
msg.role = "assistant";
msg.content = content;
}
json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
message["reasoning_content"] = msg.reasoning_content;
}
if (msg.content.empty() && !msg.tool_calls.empty()) {
message["content"] = json();
} else {
message["content"] = msg.content;
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments},
}},
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
// We only generate a random id for the ones that don't generate one by themselves
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", message},
{"message", msg.to_json_oaicompat<json>()},
};
if (!stream && probs_output.size() > 0) {
@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice = json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
};
json deltas = json::array();
for (const auto & diff : oaicompat_msg_diffs) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
}
json ret = json {
{"choices", json::array({choice})},
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};
});
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
deltas.back().push_back({"timings", timings.to_json()});
}
// extra fields for debugging purposes
if (verbose) {
ret["__verbose"] = to_json_non_oaicompat();
if (verbose && !deltas.empty()) {
deltas.front()["__verbose"] = to_json_non_oaicompat();
}
return ret;
return deltas;
}
};
@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
@@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
std::time_t t = std::time(0);
json choices;
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
// initial_ret is the role message for stream=True
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"},
{"content", ""}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json {
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
if (prob_output.probs.size() > 0) {
second_ret["choices"][0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
second_ret.push_back({"timings", timings.to_json()});
}
return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json {
{"content", content},
}},
}});
}
GGML_ASSERT(choices.size() >= 1);
if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}
std::vector<json> deltas;
auto add_delta = [&](const json & delta) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", delta},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
};
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
// We have to send an initial update to conform to openai behavior
if (first) {
add_delta({
{"role", "assistant"},
{"content", nullptr},
});
}
return std::vector<json>({ret});
for (const auto & diff : oaicompat_msg_diffs) {
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
}
if (!deltas.empty()) {
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
if (prob_output.probs.size() > 0) {
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
}
}
return deltas;
}
};
@@ -1293,6 +1266,7 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;
common_chat_msg chat_msg;
server_tokens cache_tokens;
@@ -1313,6 +1287,7 @@ struct server_slot {
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;
// stats
size_t n_sent_text = 0; // number of sent text character
@@ -1342,9 +1317,13 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
// clear speculative decoding stats
n_draft_total = 0;
@@ -1424,6 +1403,21 @@ struct server_slot {
return timings;
}
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
@@ -1891,6 +1885,7 @@ struct server_context {
float slot_prompt_similarity = 0.0f;
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
~server_context() {
mtmd_free(mctx);
@@ -2086,6 +2081,15 @@ struct server_context {
}
metrics.init();
oai_parser_opt = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
};
}
server_slot * get_slot_by_id(int id) {
@@ -2465,10 +2469,12 @@ struct server_context {
res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -2489,7 +2495,7 @@ struct server_context {
res->id_slot = slot.id;
res->index = slot.index;
res->content = std::move(slot.generated_text);
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
@@ -2509,7 +2515,8 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -3341,6 +3348,37 @@ struct server_context {
common_set_adapter_lora(ctx, slot_batched->lora);
}
const bool do_encode = (params_base.embedding || params_base.reranking);
// pad the batch so that batch.n_tokens >= n_slots
// TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
if (do_encode) {
const int n_slots = slots.size();
if (batch.n_tokens < n_slots) {
std::set<llama_seq_id> seq_ids;
for (int j = 0; j < batch.n_tokens; ++j) {
seq_ids.insert(batch.seq_id[j][0]);
}
// find unused sequence id
llama_seq_id seq_id = -1;
for (int i = 0; i < n_slots; ++i) {
if (seq_ids.find(i) == seq_ids.end()) {
seq_id = i;
}
}
const int n_add = n_slots - batch.n_tokens;
SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
for (int j = 0; j < n_add; ++j) {
common_batch_add(batch, 0, j, { seq_id }, false);
}
}
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -3357,7 +3395,7 @@ struct server_context {
int ret = 0;
if (params_base.embedding || params_base.reranking) {
if (do_encode) {
ret = llama_encode(ctx, batch_view);
} else {
ret = llama_decode(ctx, batch_view);
@@ -3366,14 +3404,29 @@ struct server_context {
metrics.on_decoded(slots);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
for (auto & slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
{
std::string err;
if (n_batch == 1 && ret == 1) {
err = "Context size has been exceeded.";
}
if (ret == -1) {
err = "Invalid input batch.";
}
if (ret < -1) {
err = "Compute error.";
}
if (!err.empty()) {
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
for (auto & slot : slots) {
slot.release();
send_error(slot, err);
}
break;
}
break; // break loop of n_batch
}
// retry with half the batch size to try to find a free slot in the KV cache
@@ -4046,7 +4099,10 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
{ "modalities", json{
{"vision", ctx_server.oai_parser_opt.allow_image},
{"audio", ctx_server.oai_parser_opt.allow_audio},
} },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@@ -4137,10 +4193,10 @@ int main(int argc, char ** argv) {
for (auto & file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image");
throw std::runtime_error("Failed to load image or audio file");
}
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
}
@@ -4372,7 +4428,7 @@ int main(int argc, char ** argv) {
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
LOG_DBG("request: %s\n", req.body.c_str());
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
@@ -4381,13 +4437,9 @@ int main(int argc, char ** argv) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files;
json data = oaicompat_completion_params_parse(
json data = oaicompat_chat_params_parse(
body,
params.use_jinja,
params.prefill_assistant,
params.reasoning_format,
ctx_server.chat_templates.get(),
ctx_server.mctx,
ctx_server.oai_parser_opt,
files);
handle_completions_impl(
@@ -4400,16 +4452,12 @@ int main(int argc, char ** argv) {
};
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
json data = oaicompat_completion_params_parse(
json data = oaicompat_chat_params_parse(
body,
params.use_jinja,
params.prefill_assistant,
params.reasoning_format,
ctx_server.chat_templates.get(),
ctx_server.mctx,
ctx_server.oai_parser_opt,
files);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
@@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
choice = data["choices"][0]
if i == 0:
# Check first role message for stream=True
assert choice["delta"]["content"] == ""
assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]
@@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
content += choice["delta"]["content"] or ''
def test_chat_completion_with_openai_library():
@@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
for i, data in enumerate(res):
if i == 0:
# Check first role message for stream=True
assert data["choices"][0]["delta"]["content"] == ""
assert data["choices"][0]["delta"]["content"] is None
assert data["choices"][0]["delta"]["role"] == "assistant"
assert "timings" not in data, f'First event should not have timings: {data}'
else:
assert "role" not in data["choices"][0]["delta"]
assert "timings" in data
@@ -311,7 +312,7 @@ def test_logprobs_stream():
choice = data.choices[0]
if i == 0:
# Check first role message for stream=True
assert choice.delta.content == ""
assert choice.delta.content is None
assert choice.delta.role == "assistant"
else:
assert choice.delta.role is None
+25
View File
@@ -47,3 +47,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
today_str = datetime.date.today().strftime(format)
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
@pytest.mark.parametrize("add_generation_prompt", [False, True])
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
])
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
global server
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "user", "content": "What is today?"},
],
"add_generation_prompt": add_generation_prompt,
})
assert res.status_code == 200
prompt = res.body["prompt"]
if add_generation_prompt:
assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
else:
assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
+84 -66
View File
@@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
from utils import *
from enum import Enum
server: ServerProcess
@@ -20,7 +21,11 @@ def create_server():
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081
server.n_slots = 1
class CompletionMode(Enum):
NORMAL = "normal"
STREAMED = "streamed"
TEST_TOOL = {
"type":"function",
@@ -73,9 +78,8 @@ WEATHER_TOOL = {
}
}
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
@@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
"parallel_tool_calls": False,
**kwargs,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"]
@@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server
n_predict = 1024
# server = ServerPreset.stories15m_moe()
@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
# Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
# ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()
@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
"stream": stream == CompletionMode.STREAMED,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
"tool_choice": tool_choice,
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.jinja = True
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meetkai-functionary-medium-v3.2", 256, [], None),
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server
server.jinja = True
server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_weather(server, max_tokens=n_predict)
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_weather(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"},
@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
"tools": [WEATHER_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
location = actual_arguments["location"]
@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_calc_result(server, result_override, n_predict)
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
(128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
server.n_slots = 1
server.reasoning_format = reasoning_format
server.jinja = True
server.n_ctx = 8192 * 2
@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"},
]
],
"stream": stream == CompletionMode.STREAMED,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
@pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server
n_predict = 512 # High because of DeepSeek R1
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_hello_world(server, max_tokens=n_predict)
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_hello_world(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a tool-calling agent."},
{"role": "user", "content": "say hello world with python"},
@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
"tools": [PYTHON_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
@@ -30,6 +30,7 @@ def create_server():
("What is this:\n", "malformed", False, None),
("What is this:\n", "https://google.com/404", False, None), # non-existent image
("What is this:\n", "https://ggml.ai", False, None), # non-image data
# TODO @ngxson : test with multiple images, no images and with audio
]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
+71
View File
@@ -294,6 +294,77 @@ class ServerProcess:
print("Partial response from server", json.dumps(data, indent=2))
yield data
def make_any_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> dict:
stream = data.get('stream', False)
if stream:
content: list[str] = []
tool_calls: list[dict] = []
finish_reason: Optional[str] = None
content_parts = 0
tool_call_parts = 0
arguments_parts = 0
for chunk in self.make_stream_request(method, path, data, headers):
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
choice = chunk['choices'][0]
if choice['delta'].get('content') is not None:
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
content.append(choice['delta']['content'])
content_parts += 1
if choice['delta'].get('finish_reason') is not None:
finish_reason = choice['delta']['finish_reason']
for tc in choice['delta'].get('tool_calls', []):
if 'function' not in tc:
raise ValueError(f"Expected function type, got {tc['type']}")
if tc['index'] >= len(tool_calls):
tool_calls.append(dict(
id="",
type="function",
function=dict(
name="",
arguments="",
)
))
tool_call = tool_calls[tc['index']]
if tc.get('id') is not None:
tool_call['id'] = tc['id']
fct = tc['function']
if fct.get('name') is not None:
tool_call['function']['name'] = fct['name']
if fct.get('arguments') is not None:
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
tool_call['function']['arguments'] += fct['arguments']
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
result = dict(
choices=[
dict(
index=0,
finish_reason=finish_reason,
message=dict(
role='assistant',
content=''.join(content) if content else None,
tool_calls=tool_calls if tool_calls else None,
),
)
],
)
print("Final response from server", json.dumps(result, indent=2))
return result
else:
response = self.make_request(method, path, data, headers, timeout=timeout)
assert response.status_code == 200, f"Server returned error: {response.status_code}"
return response.body
server_instances: Set[ServerProcess] = set()
+82 -73
View File
@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
// other common utils
//
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
if (!text.empty() && !stop.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -536,6 +516,7 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
// OAI utils
//
// used by /completions endpoint
static json oaicompat_completion_params_parse(const json & body) {
json llama_params;
@@ -580,31 +561,34 @@ static json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}
static json oaicompat_completion_params_parse(
struct oaicompat_parser_options {
bool use_jinja;
bool prefill_assistant;
common_reasoning_format reasoning_format;
common_chat_templates * tmpls;
bool allow_image;
bool allow_audio;
};
// used by /chat/completions endpoint
static json oaicompat_chat_params_parse(
const json & body, /* openai api json semantics */
bool use_jinja,
bool prefill_assistant,
common_reasoning_format reasoning_format,
const struct common_chat_templates * tmpls,
bool allow_non_text,
const oaicompat_parser_options & opt,
std::vector<raw_buffer> & out_files)
{
json llama_params;
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", false);
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tools.is_array() && !tools.empty()) {
if (stream) {
throw std::runtime_error("Cannot use tools with stream");
}
if (!use_jinja) {
if (!opt.use_jinja) {
if (has_tools) {
throw std::runtime_error("tools param requires --jinja flag");
}
}
if (!use_jinja) {
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
throw std::runtime_error("Unsupported param: tool_choice");
if (tool_choice != "auto") {
throw std::runtime_error("tool_choice param requires --jinja flag");
}
}
@@ -667,12 +651,12 @@ static json oaicompat_completion_params_parse(
for (auto & p : content) {
std::string type = json_value(p, "type", std::string());
json image_url = json_value(p, "image_url", json::object());
if (type == "image_url") {
if (!allow_non_text) {
throw std::runtime_error("image input is not supported by this server");
if (!opt.allow_image) {
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json image_url = json_value(p, "image_url", json::object());
std::string url = json_value(image_url, "url", std::string());
if (string_starts_with(url, "http")) {
// download remote image
@@ -710,8 +694,31 @@ static json oaicompat_completion_params_parse(
// replace this chunk with a marker
p["type"] = "text";
p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
p["text"] = mtmd_default_marker();
p.erase("image_url");
} else if (type == "input_audio") {
if (!opt.allow_audio) {
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_audio = json_value(p, "input_audio", json::object());
std::string data = json_value(input_audio, "data", std::string());
std::string format = json_value(input_audio, "format", std::string());
// while we also support flac, we don't allow it here so we matches the OAI spec
if (format != "wav" && format != "mp3") {
throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
}
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
// replace this chunk with a marker
p["type"] = "text";
p["text"] = mtmd_default_marker();
p.erase("input_audio");
} else if (type != "text") {
throw std::runtime_error("unsupported content[].type");
}
}
}
@@ -719,21 +726,20 @@ static json oaicompat_completion_params_parse(
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = use_jinja;
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.reasoning_format = opt.reasoning_format;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
// if the assistant message appears at the end of list, we do not add end-of-turn token
// for ex. this can be useful to modify the reasoning process in reasoning models
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant;
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
common_chat_msg last_message;
if (prefill_assistant_message) {
last_message = inputs.messages.back();
@@ -744,12 +750,13 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
}
inputs.extract_reasoning = false;
/* TODO: test this properly */
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = true;
}
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(tmpls, inputs);
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
/* Append assistant prefilled message */
if (prefill_assistant_message) {
@@ -769,6 +776,7 @@ static json oaicompat_completion_params_parse(
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
@@ -782,6 +790,9 @@ static json oaicompat_completion_params_parse(
// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) {
if (has_tools && stream) {
throw std::runtime_error("logprobs is not supported with tools + stream");
}
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
@@ -1040,7 +1051,7 @@ struct server_tokens {
private: // disallow accessing these members directly, risking out-of-sync
// map a **start** position in tokens to the image chunk
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
// list of tokens
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
@@ -1051,7 +1062,7 @@ private: // disallow accessing these members directly, risking out-of-sync
// for ex. with input of 5 text tokens and 2 images:
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
// pos 0 1 2 3 4 5 6 7 8 9
// map_pos_to_image will contain: {5, img0}, {8, img1}
// map_pos_to_media will contain: {5, img0}, {8, img1}
public:
server_tokens() = default;
@@ -1090,15 +1101,15 @@ public:
}
oss << "\n";
oss << "image pos: ";
for (const auto & it : map_pos_to_image) {
for (const auto & it : map_pos_to_media) {
oss << it.first << ", ";
}
return oss.str();
}
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
auto it = map_pos_to_image.find(pos);
if (it != map_pos_to_image.end()) {
auto it = map_pos_to_media.find(pos);
if (it != map_pos_to_media.end()) {
return it->second;
} else {
throw std::runtime_error("Chunk not found");
@@ -1115,16 +1126,15 @@ public:
// will create a copy of the chunk if it contains non-text data
void push_back(const mtmd_input_chunk * chunk) {
auto type = mtmd_input_chunk_get_type(chunk);
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
GGML_ASSERT(has_mtmd);
auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
llama_pos start_pos = tokens.size();
for (int i = 0; i < n_pos; ++i) {
tokens.emplace_back(LLAMA_TOKEN_NULL);
}
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_pos_to_image[start_pos] = std::move(new_chunk);
map_pos_to_media[start_pos] = std::move(new_chunk);
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens;
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1169,6 +1179,9 @@ public:
void keep_first(size_t n) {
GGML_ASSERT(n <= tokens.size());
if (has_mtmd) {
if (n == tokens.size()) {
return; // nothing to do
}
// we throw an error if we try to remove a token in the middle of an image
// for ex. with input of 5 text tokens and 2 images:
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
@@ -1183,10 +1196,10 @@ public:
}
}
// remove all image chunks that are not used anymore
for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
llama_pos pos = it->first;
if (pos >= (llama_pos)n) {
it = map_pos_to_image.erase(it);
it = map_pos_to_media.erase(it);
} else {
++it;
}
@@ -1217,14 +1230,12 @@ public:
const auto & a_chunk = find_chunk(i);
const auto & b_chunk = b.find_chunk(i);
GGML_ASSERT(a_chunk && b_chunk);
const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
std::string ai_id = mtmd_image_tokens_get_id(a_img);
std::string bi_id = mtmd_image_tokens_get_id(b_img);
size_t a_pos = mtmd_image_tokens_get_n_pos(a_img);
size_t b_pos = mtmd_image_tokens_get_n_pos(b_img);
std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
if (ai_id == bi_id && a_pos == b_pos) {
GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
i += a_pos - 1; // will be +1 by the for loop
continue;
} else {
@@ -1250,8 +1261,7 @@ public:
if (t == LLAMA_TOKEN_NULL) {
try {
const auto & chunk = find_chunk(i);
const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
i += n_pos - 1; // will be +1 by the for loop
} catch (const std::exception & e) {
return false;
@@ -1270,22 +1280,21 @@ public:
llama_pos n_past,
int32_t seq_id,
llama_pos & n_pos_out) {
auto it = map_pos_to_image.find(n_past);
if (it == map_pos_to_image.end()) {
throw std::runtime_error("Chunk not found");
}
SRV_INF("%s\n", "processing image...");
auto & chunk = find_chunk(n_past);
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio";
SRV_INF("processing %s...\n", name);
int32_t n_batch = llama_n_batch(ctx);
int64_t t0 = ggml_time_ms();
llama_pos new_n_past = n_past;
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
it->second.get(), // chunk
chunk.get(),
n_past,
seq_id,
n_batch,
true, // logits last
&new_n_past);
SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result);
n_pos_out = n_past;
@@ -1,4 +1,8 @@
import { DocumentTextIcon, XMarkIcon } from '@heroicons/react/24/outline';
import {
DocumentTextIcon,
SpeakerWaveIcon,
XMarkIcon,
} from '@heroicons/react/24/outline';
import { MessageExtra } from '../utils/types';
import { useState } from 'react';
import { classNames } from '../utils/misc';
@@ -66,7 +70,11 @@ export default function ChatInputExtraContextItem({
className="w-14 h-14 flex items-center justify-center"
aria-description="Document icon"
>
<DocumentTextIcon className="h-8 w-14 text-base-content/50" />
{item.type === 'audioFile' ? (
<SpeakerWaveIcon className="h-8 w-8 text-gray-500" />
) : (
<DocumentTextIcon className="h-8 w-8 text-gray-500" />
)}
</div>
<div className="text-xs pr-4">
@@ -98,6 +106,19 @@ export default function ChatInputExtraContextItem({
src={showingItem.base64Url}
alt={`Preview image for ${showingItem.name}`}
/>
) : showingItem.type === 'audioFile' ? (
<audio
controls
className="w-full"
aria-description={`Audio file ${showingItem.name}`}
>
<source
src={`data:${showingItem.mimeType};base64,${showingItem.base64Data}`}
type={showingItem.mimeType}
aria-description={`Audio file ${showingItem.name}`}
/>
Your browser does not support the audio element.
</audio>
) : (
<div className="overflow-x-auto">
<pre className="whitespace-pre-wrap break-words text-sm">
@@ -278,6 +278,13 @@ export default function ChatScreen() {
function ServerInfo() {
const { serverProps } = useAppContext();
const modalities = [];
if (serverProps?.modalities?.audio) {
modalities.push('audio');
}
if (serverProps?.modalities?.vision) {
modalities.push('vision');
}
return (
<div
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
@@ -291,6 +298,13 @@ function ServerInfo() {
<br />
<b>Build</b>: {serverProps?.build_info}
<br />
{modalities.length > 0 ? (
<>
<b>Supported modalities:</b> {modalities.join(', ')}
</>
) : (
''
)}
</p>
</div>
</div>
@@ -11,6 +11,7 @@ pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
// This file handles uploading extra context items (a.k.a files)
// It allows processing these kinds of files:
// - image files (converted to base64)
// - audio files (converted to base64)
// - text files (including code files)
// - pdf (converted to text)
@@ -41,96 +42,73 @@ export function useChatExtraContext(): ChatExtraContextApi {
const isSupportVision = serverProps?.modalities?.vision;
const onFileAdded = (files: File[]) => {
for (const file of files) {
const mimeType = file.type;
console.debug({ mimeType, file });
if (file.size > 10 * 1024 * 1024) {
toast.error('File is too large. Maximum size is 10MB.');
break;
}
if (mimeType.startsWith('image/')) {
if (!isSupportVision) {
toast.error('Multimodal is not supported by this server or model.');
const onFileAdded = async (files: File[]) => {
try {
for (const file of files) {
const mimeType = file.type;
if (file.size > 10 * 1024 * 1024) {
toast.error('File is too large. Maximum size is 10MB.');
break;
}
const reader = new FileReader();
reader.onload = async (event) => {
if (event.target?.result) {
let base64Url = event.target.result as string;
if (mimeType === 'image/svg+xml') {
// Convert SVG to PNG
base64Url = await svgBase64UrlToPngDataURL(base64Url);
}
if (mimeType.startsWith('image/')) {
if (!isSupportVision) {
toast.error('Multimodal is not supported by this server or model.');
break;
}
addItems([
{
let base64Url = await getFileAsBase64(file);
if (mimeType === 'image/svg+xml') {
// Convert SVG to PNG
base64Url = await svgBase64UrlToPngDataURL(base64Url);
}
addItems([
{
type: 'imageFile',
name: file.name,
base64Url,
},
]);
} else if (mimeType.startsWith('video/')) {
toast.error('Video files are not supported yet.');
break;
} else if (mimeType.startsWith('audio/')) {
if (!/mpeg|wav/.test(mimeType)) {
toast.error('Only mp3 and wav audio files are supported.');
break;
}
// plain base64, not a data URL
const base64Data = await getFileAsBase64(file, false);
addItems([
{
type: 'audioFile',
name: file.name,
mimeType,
base64Data,
},
]);
} else if (mimeType.startsWith('application/pdf')) {
if (config.pdfAsImage && !isSupportVision) {
toast(
'Multimodal is not supported, PDF will be converted to text instead of image.'
);
break;
}
if (config.pdfAsImage && isSupportVision) {
// Convert PDF to images
const base64Urls = await convertPDFToImage(file);
addItems(
base64Urls.map((base64Url) => ({
type: 'imageFile',
name: file.name,
base64Url,
},
]);
}
};
reader.readAsDataURL(file);
} else if (
mimeType.startsWith('video/') ||
mimeType.startsWith('audio/')
) {
toast.error('Video and audio files are not supported yet.');
break;
} else if (mimeType.startsWith('application/pdf')) {
if (config.pdfAsImage && !isSupportVision) {
toast(
'Multimodal is not supported, PDF will be converted to text instead of image.'
);
break;
}
const promise =
config.pdfAsImage && isSupportVision
? convertPDFToImage(file).then((base64Urls) => {
addItems(
base64Urls.map((base64Url) => ({
type: 'imageFile',
name: file.name,
base64Url,
}))
);
})
: convertPDFToText(file).then((content) => {
if (isSupportVision) {
toast.success(
'PDF file converted to text. You can also convert it to image, see in Settings.'
);
}
addItems([
{
type: 'textFile',
name: file.name,
content,
},
]);
});
promise.catch((error) => {
console.error(error);
toast.error('Failed to parse PDF file.');
});
break;
} else {
// Because there can be many text file types (like code file), we will not check the mime type
// and will just check if the file is not binary.
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
const content = event.target.result as string;
if (!isLikelyNotBinary(content)) {
toast.error('File is binary. Please upload a text file.');
return;
}
}))
);
} else {
// Convert PDF to text
const content = await convertPDFToText(file);
addItems([
{
type: 'textFile',
@@ -138,10 +116,40 @@ export function useChatExtraContext(): ChatExtraContextApi {
content,
},
]);
if (isSupportVision) {
toast.success(
'PDF file converted to text. You can also convert it to image, see in Settings.'
);
}
}
};
reader.readAsText(file);
break;
} else {
// Because there can be many text file types (like code file), we will not check the mime type
// and will just check if the file is not binary.
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
const content = event.target.result as string;
if (!isLikelyNotBinary(content)) {
toast.error('File is binary. Please upload a text file.');
return;
}
addItems([
{
type: 'textFile',
name: file.name,
content,
},
]);
}
};
reader.readAsText(file);
}
}
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
const errorMessage = `Error processing file: ${message}`;
toast.error(errorMessage);
}
};
@@ -154,6 +162,25 @@ export function useChatExtraContext(): ChatExtraContextApi {
};
}
async function getFileAsBase64(file: File, outputUrl = true): Promise<string> {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
let result = event.target.result as string;
if (!outputUrl) {
// remove base64 url prefix and correct characters
result = result.substring(result.indexOf(',') + 1);
}
resolve(result);
} else {
reject(new Error('Failed to read file.'));
}
};
reader.readAsDataURL(file);
});
}
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
return new Promise((resolve, reject) => {
const reader = new FileReader();
+8
View File
@@ -89,6 +89,14 @@ export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
type: 'image_url',
image_url: { url: extra.base64Url },
});
} else if (extra.type === 'audioFile') {
contentArr.push({
type: 'input_audio',
input_audio: {
data: extra.base64Data,
format: /wav/.test(extra.mimeType) ? 'wav' : 'mp3',
},
});
} else {
throw new Error('Unknown extra type');
}
+13
View File
@@ -51,6 +51,7 @@ export interface Message {
export type MessageExtra =
| MessageExtraTextFile
| MessageExtraImageFile
| MessageExtraAudioFile
| MessageExtraContext;
export interface MessageExtraTextFile {
@@ -65,6 +66,13 @@ export interface MessageExtraImageFile {
base64Url: string;
}
export interface MessageExtraAudioFile {
type: 'audioFile';
name: string;
base64Data: string;
mimeType: string;
}
export interface MessageExtraContext {
type: 'context';
name: string;
@@ -79,6 +87,10 @@ export type APIMessageContentPart =
| {
type: 'image_url';
image_url: { url: string };
}
| {
type: 'input_audio';
input_audio: { data: string; format: 'wav' | 'mp3' };
};
export type APIMessage = {
@@ -120,6 +132,7 @@ export interface LlamaCppServerProps {
n_ctx: number;
modalities?: {
vision: boolean;
audio: boolean;
};
// TODO: support params
}
+4 -2
View File
@@ -579,6 +579,8 @@ int main(int argc, char ** argv) {
params.model = params.vocoder.model;
params.embedding = true;
params.ctx_shift = false; // silence warning
params.n_ubatch = params.n_batch;
common_init_result llama_init_cts = common_init_from_params(params);
@@ -1020,8 +1022,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}
GGML_ASSERT(batch.n_tokens == n_codes);
if (llama_decode(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
if (llama_encode(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_encode() failed\n", __func__);
return 1;
}