Compare commits

..

59 Commits

Author SHA1 Message Date
David Zhao 79c1160b07 cuda: refactored ssm_scan and use CUB (#13291)
* cuda: refactored ssm_scan to use CUB

* fixed compilation error when when not using CUB

* assign L to constant and use size_t instead of int

* deduplicated functions

* change min blocks per mp to 1

* Use cub load and store warp transpose

* suppress clang warning
2025-08-09 20:29:43 +02:00
Aman Gupta 34c9d765bf CUDA: add attention sinks for tile and wmma (#15178)
* CUDA: add attention sinks for tile and wmma

* Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma
2025-08-09 20:00:24 +08:00
compilade e54d41befc gguf-py : add Numpy MXFP4 de/quantization support (#15111)
* gguf-py : add MXFP4 de/quantization support

* ggml-quants : handle zero amax for MXFP4
2025-08-08 17:48:26 -04:00
Johannes Gäßler 4850b52aed server-bench: external OAI servers, sqlite (#15179)
* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-08-08 23:04:36 +02:00
AN Long cd6983d56d ggml : fix field name when new ggml_backend (#14944) 2025-08-08 14:37:22 +02:00
Olivier Chafik 6c7e9a5440 vendor: sync minja (#15161)
* vendor: sync minja

* Update minja.hpp

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-08-08 10:45:18 +01:00
Johannes Gäßler 1425f587a8 CUDA: attention sinks for mma FlashAttention (#15157) 2025-08-08 08:19:58 +02:00
lhez aaa3d07ae7 opencl: support sink in soft_max (attn sinks) (#15152) 2025-08-07 21:47:03 -07:00
Xuan-Son Nguyen 50aa938901 convert : support non-mxfp4 HF model (#15153)
* convert : support non-mxfp4 HF model

* rm redundant check

* disable debug check
2025-08-07 23:26:03 +02:00
Jeff Bolz c4f53563df vulkan: support fattn sinks (#15126) 2025-08-07 22:44:20 +02:00
Jeff Bolz a0552c8bee vulkan: Add env var to disable host visible vidmem (#15109) 2025-08-07 22:07:11 +02:00
RunningLeon 99acbc9921 llama : Support intern-s1 (#14875)
* support internvl

* support interns1

* resolve comments

* put interns1 in tensor mapping

* resolve comment

* move tokenizer changes to sub class
2025-08-07 18:20:40 +02:00
uvos 7ad67ba9fe HIP: add cmake option to enable compiler output of kernel resource usage metrics (#15103) 2025-08-07 16:44:14 +02:00
Christian Kastner 9a96389544 ggml: Skip backend library linking code when GGML_BACKEND_DL=ON (#15094)
Any available libraries are found and loaded dynamically at runtime.
2025-08-07 13:45:41 +02:00
Johannes Gäßler 1d72c84188 CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131)
* CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16
2025-08-07 10:53:21 +02:00
Johannes Gäßler 20638e4f16 scripts: fix crash when --tool is not set (#15133) 2025-08-07 08:50:30 +02:00
Daniel Bevenius 36d3f00e14 requirements : fix PyTorch uint64 compatibility (#15134)
This commit addresses an issue with the convert_hf_to_gguf script
which is currently failing with:
```console
AttributeError: module 'torch' has no attribute 'uint64'
```

This occurred because safetensors expects torch.uint64 to be available
in the public API, but PyTorch 2.2.x only provides limited support for
unsigned types beyond uint8 it seems. The torch.uint64 dtype exists but
is not exposed in the standard torch namespace
(see pytorch/pytorch#58734).

PyTorch 2.4.0 properly exposes torch.uint64 in the public API, resolving
the compatibility issue with safetensors. This also required torchvision
to updated to =0.19.0 for compatibility.

Refs: https://huggingface.co/spaces/ggml-org/gguf-my-repo/discussions/186#68938de803e47d990aa087fb
Refs: https://github.com/pytorch/pytorch/issues/58734
2025-08-07 05:31:48 +02:00
Reese Levine 5fd160bbd9 ggml: Add basic SET_ROWS support in WebGPU (#15137)
* Begin work on set_rows

* Work on set rows

* Add error buffers for reporting unsupported SET_ROWS indices

* Remove extra comments
2025-08-06 15:14:40 -07:00
rmatif 756cfea826 fix profiling crash (#15072) 2025-08-06 14:17:51 -07:00
lhez e725a1a982 opencl: add swiglu_oai and add_id (#15121)
* opencl: add `swiglu-oai`

* opencl: add `add_id`

* opencl: add missing `add_id.cl`
2025-08-06 12:12:17 -07:00
Sachin Desai 3db4da56a5 chat : support Granite model reasoning and tool call (#14864) 2025-08-06 20:27:30 +02:00
Juk Armstrong 476aa3fd57 Fixed name -override-tensors to -override-tensor (#15129) 2025-08-06 17:28:48 +01:00
Diego Devesa 0d8831543c ggml : fix fallback to CPU for ununsupported ops (#15118) 2025-08-06 14:37:35 +02:00
Sigbjørn Skjæret 65c797c4fa chat : fix yandex chat template (#15116) 2025-08-06 13:26:49 +02:00
stevenkuang 25726898e8 chat : fix hunyuan auto-detection (#15114)
Signed-off-by: stevenkuang <stevenkuang@tencent.com>
2025-08-06 11:48:30 +02:00
Chenguang Li 2241453252 CANN: add support for ACL Graph (#15065)
* feat(cann): add optional support for ACL Graph execution

This commit adds support for executing ggml computational graphs using
Huawei's ACL graph mode via the USE_CANN_GRAPH flag. The support can be
enabled at compile time using the CMake option:

    -DUSE_CANN_GRAPH=ON

By default, ACL graph execution is **disabled**, and the fallback path
uses node-by-node execution.

Key additions:
- CMake option  to toggle graph mode
- Graph capture and execution logic using
- Tensor property matching to determine whether graph update is required
- Safe fallback and logging if the environment variable LLAMA_SET_ROWS
  is unset or invalid

This prepares the backend for performance improvements in repetitive graph
execution scenarios on Ascend devices.

Signed-off-by: noemotiovon <757486878@qq.com>

* Fix review comments

Signed-off-by: noemotiovon <757486878@qq.com>

* remane USE_CANN_GRAPH to USE_ACL_GRAPH

Signed-off-by: noemotiovon <757486878@qq.com>

* fix typo

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
2025-08-06 14:12:42 +08:00
Reese Levine 9515c6131a ggml: WebGPU disable SET_ROWS for now (#15078)
* Add paramater buffer pool, batching of submissions, refactor command building/submission

* Add header for linux builds

* Free staged parameter buffers at once

* Format with clang-format

* Fix thread-safe implementation

* Use device implicit synchronization

* Update workflow to use custom release

* Remove testing branch workflow

* Disable set_rows until it's implemented

* Fix potential issue around empty queue submission

* Try synchronous submission

* Try waiting on all futures explicitly

* Add debug

* Add more debug messages

* Work on getting ssh access for debugging

* Debug on failure

* Disable other tests

* Remove extra if

* Try more locking

* maybe passes?

* test

* Some cleanups

* Restore build file

* Remove extra testing branch ci
2025-08-05 16:26:38 -07:00
Georgi Gerganov fd1234cb46 llama : add gpt-oss (#15091)
* oai moe

* compat with new checkpoint

* add attn sink impl

* add rope scaling yarn

* logits match with latest transformers code

* wip chat template

* rm trailing space

* use ggml_scale_bias

* rm redundant is_swa_all

* convert interleaved gate_up

* graph : fix activation function to match reference (#7)

* vocab : handle o200k_harmony special tokens

* ggml : add attention sinks support (#1)

* llama : add attn sinks

* ggml : add attn sinks

* cuda : add attn sinks

* vulkan : add support for sinks in softmax

remove unnecessary return

* ggml : add fused swiglu_oai op (#11)

* ggml : add fused swiglu_oai op

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update CUDA impl

* cont : metal impl

* add vulkan impl

* test-backend-ops : more test cases, clean up

* llama : remove unfused impl

* remove extra lines

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>

* repack mxfp4 upon conversion

* clean up a bit

* enable thinking

* add quick hack to render only some special tokens

* fix bf16 conversion

* remove vocab hack

* webui ok

* support chat parsing for gpt-oss

* fix webui

* direct mapping mxfp4, FINALLY

* force using mxfp4

* properly use lazy tensor

* ggml : add mxfp4

ggml : use e8m0 conversion instead of powf

Co-authored-by: Diego Devesa <slarengh@gmail.com>

change kvalues_mxfp4 table to match e2m1 (#6)

metal : remove quantization for now (not used)

cuda : fix disabled CUDA graphs due to ffn moe bias

vulkan : add support for mxfp4

cont : add cm2 dequant

* ggml : add ggml_add_id (#13)

* ggml : add ggml_add_id

* add cuda impl

* llama : add weight support check for add_id

* perf opt

* add vulkan impl

* rename cuda files

* add metal impl

* allow in-place ggml_add_id

* llama : keep biases on CPU with --cpu-moe

* llama : fix compile error

ggml-ci

* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw

ggml-ci

* cleanup

ggml-ci

* sycl : fix supports_op for MXFP4

ggml-ci

* fix Unknown reasoning format

* ggml-cpu : fix AVX build

ggml-ci

* fix hip build

ggml-ci

* cuda : add mxfp4 dequantization support for cuBLAS

ggml-ci

* ggml-cpu : fix mxfp4 fallback definitions for some architectures

ggml-ci

* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: slaren <slarengh@gmail.com>
2025-08-05 22:10:36 +03:00
Sigbjørn Skjæret f324a3b715 chat : only remove double bos/eos if added (#15086)
* only remove double bos/eos if added

* fix tests
2025-08-05 20:43:36 +02:00
Georgi Gerganov be42642581 readme : update hot topics (#15097) 2025-08-05 20:19:33 +03:00
Romain Biessy 3306ceabf0 sycl: fix mul_mat selection (#15092) 2025-08-05 18:39:55 +02:00
Juk Armstrong c81de6e107 Fix glm4moe bug (#15088) 2025-08-05 13:56:44 +01:00
Alex Wu 22f060c9c4 webui: fix markdown table (#15081)
* webui: fix markdown table

* webui: fix table display with themes
2025-08-05 13:56:44 +02:00
compilade ee3a9fcf88 context : fix index overflow on huge outputs (#15080)
* context : fix overflow when re-ordering huge outputs

* context : fix logits size overflow for huge batches
2025-08-05 11:27:45 +02:00
Diego Devesa ec428b02c3 llama : add --n-cpu-moe option (#15077)
* llama : add --n-cpu-moe option

Keeps the MoE weights of the first N layers in the CPU
2025-08-05 01:05:36 +02:00
compilade 19f68fa5a4 imatrix : warn when GGUF imatrix is saved without .gguf suffix (#15076)
* imatrix : add warning when suffix is not .gguf for GGUF imatrix

* imatrix : only warn about suffix when output format is unspecified
2025-08-04 23:26:52 +02:00
Christian Kastner 41613437ff cmake: Add GGML_BACKEND_DIR option (#15074)
* cmake: Add GGML_BACKEND_DIR option

This can be used by distributions to specify where to look for backends
when ggml is built with GGML_BACKEND_DL=ON.

* Fix phrasing
2025-08-04 21:29:14 +02:00
Sigbjørn Skjæret e5bebe5251 gguf-py : add --chat-template-file to gguf_new_metadata (#15075) 2025-08-04 21:01:48 +02:00
Sam ef0144c087 model: support GLM 4.5 family of models (#14939)
* model: Add GLM 4.5 (#14921)

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Merge in PR suggestions

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* model: Add GLM 4.5 family of models (#14921)

1. Updated tensor_mapping.py with NextN tensor mappings

- Added proper tensor mappings for all NextN/MTP tensors in /Users/samm/git/llama.cpp/gguf-py/gguf/tensor_mapping.py
- Added mappings for: eh_proj, embed_tokens, enorm, hnorm, shared_head.head, shared_head.norm

2. Added num_nextn_predict_layers configuration

- Added LLM_KV_NUM_NEXTN_PREDICT_LAYERS constant to llama-arch.h and llama-arch.cpp
- Added num_nextn_predict_layers field to llama_hparams struct
- Updated GLM4_MOE parameter loading in llama-model.cpp to read this parameter
- Modified tensor loading logic to conditionally load NextN tensors based on num_nextn_predict_layers
- Added GGUF writer support in gguf_writer.py with add_num_nextn_predict_layers() method
- Updated conversion script to extract and write this parameter from HuggingFace config

3. Added FIM tokens for GLM4_MOE

- Added GLM-4.5's FIM tokens to llama-vocab.cpp:
  - <|code_prefix|> for FIM_PRE
  - <|code_suffix|> for FIM_SUF
  - <|code_middle|> for FIM_MID

4. Removed manual NextN tensor handling

- Removed the special-case handling in convert_hf_to_gguf.py that manually mapped NextN tensors
- NextN tensors are now handled automatically through the proper tensor mapping system

* glm 4.5 update tensors names

* model: glm 4.5 apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* model: glm 4.5 apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* model: glm 4.5 apply suggestions from code review

* Apply suggestions from code review

* patch broken chat template

* typings fix

* add TENSOR_SKIP flag


Co-authored-by: Diego Devesa <slarengh@gmail.com>

* Update src/llama-model-loader.h

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-08-04 20:29:25 +02:00
Sigbjørn Skjæret 2721257e3e quantize : fix confusing error message if ftype is invalid (#15071) 2025-08-04 18:11:02 +02:00
Reese Levine 587d0118f5 ggml: WebGPU backend host improvements and style fixing (#14978)
* Add parameter buffer pool, batching of submissions, refactor command building/submission

* Add header for linux builds

* Free staged parameter buffers at once

* Format with clang-format

* Fix thread-safe implementation

* Use device implicit synchronization

* Update workflow to use custom release

* Remove testing branch workflow
2025-08-04 08:52:43 -07:00
Jeff Bolz 5aa1105da2 vulkan: fix build when using glslang that does not support coopmat2 (#15062) 2025-08-04 07:09:19 +02:00
compilade d31192b4ee imatrix : use GGUF by default (#14842)
* imatrix : use GGUF by default

* imatrix : use GGUF regardless of the output filename

The legacy format can only be produced with --output-format dat
2025-08-03 22:00:05 +02:00
compilade 0a2f5496be imatrix : fix 3d activation handling for hybrid and recurrent models (#14994)
* imatrix : use a single count for dense 3d tensors

* imatrix : fix 3d activations when model tensor is 2d

* imatrix : fix 3d tensor counts
2025-08-03 21:49:13 +02:00
compilade 11a3811164 memory : handle kv_unified for hybrid models (#15050) 2025-08-03 21:43:07 +02:00
Csaba Kecskemeti 97366dc6ab vocab : JetBrains Mellum pre-tokenizer (#15045) 2025-08-03 21:38:18 +02:00
Gabriel Larson 83bc2f288c model : add text-only support for Kimi-VL (and find special tokens in text_config) (#15051)
* basic kimi-vl textmodel conversion

* check config["text_config"] for special tokens
2025-08-03 16:56:25 +02:00
Jeff Bolz 6c7a441161 vulkan: Use coopmat2 for conv2d (#14982) 2025-08-03 14:23:57 +02:00
lhez 5c0eb5ef54 opencl: fix adreno compiler detection logic (#15029) 2025-08-02 19:51:18 +02:00
Johannes Gäßler 03d4698218 CUDA: use mma FA kernel for gqa > 4 on RTX 4000 (#15035) 2025-08-02 16:37:08 +02:00
leejet 3303c19b16 cuda: make im2col a little faster (#15025) 2025-08-02 17:15:36 +03:00
Daniel Bevenius 4fdea540bd kv-cache : skip alignment of n_stream in kv-cache log msg [no ci] (#15040)
This commit removes the right alignment the `n_stream` value in the
log message in the `llama_kv_cache_unified` constructor.

The motivation for this change is to enhance the readability of log
message. Currently the output looks like this:
```console
llama_kv_cache_unified: size = 2048.00 MiB (  4096 cells,  32 layers,  1/ 1 seqs), K (f16): 1024.00 MiB, V (f16): 1024.00 MiB
```
Notice that the `n_stream` value is right aligned, which makes it a
little harder to read.

With the change in this commit the output will look like
```console
llama_kv_cache_unified: size = 2048.00 MiB (  4096 cells,  32 layers, 1/1 seqs), K (f16): 1024.00 MiB, V (f16): 1024.00 MiB
```
2025-08-02 17:14:57 +03:00
Georgi Gerganov a4569c41fd llama : enable LLAMA_SET_ROWS=1 by default (#14959)
ggml-ci
2025-08-02 17:14:21 +03:00
Georgi Gerganov 15e92fd337 cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1 (#15038)
* cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1

ggml-ci

* cont : fix cont types

ggml-ci

* cont : adopt variable names and comment from the other branch
2025-08-02 17:13:05 +03:00
Sigbjørn Skjæret 2bf3fbf0b5 ci : check that pre-tokenizer hashes are up-to-date (#15032)
* torch is not required for convert_hf_to_gguf_update

* add --check-missing parameter

* check that pre-tokenizer hashes are up-to-date
2025-08-02 14:39:01 +02:00
Douglas Hanley 711d5e6fe6 convert : fix Qwen3-Embedding pre-tokenizer hash (#15030) 2025-08-02 12:51:02 +02:00
Jhen-Jie Hong f738989dcb chat : fix multiple tool_calls on hermes-2-pro (#14962) 2025-08-02 18:04:48 +08:00
Jeff Bolz 4cb208c93c vulkan: coopmat2 mul_mat optimizations (#14934)
- Increase tile size for k-quants, to match non-k-quants
- Choose more carefully between large and medium tiles, considering how it
  interacts with split_k
- Allow larger/non-power of two split_k, and make the splits a multiple of 256
- Use split_k==3 to when >1/2 and <=2/3 of the SMs would hae been used
2025-08-02 11:21:37 +02:00
R0CKSTAR 3025b621d1 llama-bench: rename DB table name from test to llama_bench (#15003)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2025-08-02 17:20:40 +08:00
150 changed files with 7203 additions and 1405 deletions
+16 -48
View File
@@ -159,31 +159,15 @@ jobs:
- name: Dawn Dependency
id: dawn-depends
run: |
ARTIFACTS_JSON=$(curl -s -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
"https://api.github.com/repos/google/dawn/actions/artifacts")
echo "Finding latest macos-latest-Release artifact..."
DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
| sort_by(.created_at)
| reverse
| map(select(.name | test("macos-latest-Release$")))
| .[0].archive_download_url')
if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
echo "No suitable Dawn artifact found!"
exit 1
fi
echo "Downloading from: $DOWNLOAD_URL"
curl -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-o artifact.zip "$DOWNLOAD_URL"
unzip artifact.zip
DAWN_VERSION="v1.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
curl -L -o artifact.tar.gz \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
mkdir dawn
tar_file=$(find . -name '*.tar.gz' | head -n 1)
echo "Extracting: $tar_file"
tar -xvf "$tar_file" -C dawn --strip-components=1
tar -xvf artifact.tar.gz -C dawn --strip-components=1
- name: Build
id: cmake_build
@@ -433,31 +417,15 @@ jobs:
id: dawn-depends
run: |
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
ARTIFACTS_JSON=$(curl -s -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
"https://api.github.com/repos/google/dawn/actions/artifacts")
echo "Finding latest ubuntu-latest-Release artifact..."
DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
| sort_by(.created_at)
| reverse
| map(select(.name | test("ubuntu-latest-Release$")))
| .[0].archive_download_url')
if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
echo "No suitable Dawn artifact found!"
exit 1
fi
echo "Downloading from: $DOWNLOAD_URL"
curl -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-o artifact.zip "$DOWNLOAD_URL"
unzip artifact.zip
DAWN_VERSION="v1.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
curl -L -o artifact.tar.gz \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
mkdir dawn
tar_file=$(find . -name '*.tar.gz' | head -n 1)
echo "Extracting: $tar_file"
tar -xvf "$tar_file" -C dawn --strip-components=1
tar -xvf artifact.tar.gz -C dawn --strip-components=1
- name: Build
id: cmake_build
@@ -0,0 +1,45 @@
name: Check Pre-Tokenizer Hashes
on:
push:
paths:
- 'convert_hf_to_gguf.py'
- 'convert_hf_to_gguf_update.py'
pull_request:
paths:
- 'convert_hf_to_gguf.py'
- 'convert_hf_to_gguf_update.py'
jobs:
pre-tokenizer-hashes:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install Python dependencies
run: |
python3 -m venv .venv
.venv/bin/pip install -r requirements/requirements-convert_hf_to_gguf_update.txt
- name: Update pre-tokenizer hashes
run: |
cp convert_hf_to_gguf.py /tmp
.venv/bin/python convert_hf_to_gguf_update.py --check-missing
- name: Check if committed pre-tokenizer hashes matches generated version
run: |
if ! diff -q convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py; then
echo "Model pre-tokenizer hashes (in convert_hf_to_gguf.py) do not match generated hashes (from convert_hf_to_gguf_update.py)."
echo "To fix: run ./convert_hf_to_gguf_update.py and commit the updated convert_hf_to_gguf.py along with your changes"
echo "Differences found:"
diff convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py || true
exit 1
fi
echo "Model pre-tokenizer hashes are up to date."
+1
View File
@@ -17,6 +17,7 @@ LLM inference in C/C++
## Hot topics
- Support for the `gpt-oss` model with native MXFP4 format has been added | [PR](https://github.com/ggml-org/llama.cpp/pull/15091) | [Collaboration with NVIDIA](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) | [Comment](https://github.com/ggml-org/llama.cpp/discussions/15095)
- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
+34 -8
View File
@@ -24,6 +24,7 @@
#include <cstdarg>
#include <filesystem>
#include <fstream>
#include <list>
#include <regex>
#include <set>
#include <string>
@@ -2375,20 +2376,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
throw std::invalid_argument("unknown buffer type");
}
// FIXME: this leaks memory
params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(tensor_name);
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
}
}
));
add_opt(common_arg(
{"--cpu-moe"},
"use CPU for Mixture of Experts (MoE) weights",
{"--cpu-moe", "-cmoe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) {
params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.weight$", ggml_backend_cpu_buffer_type()});
params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.weight$", ggml_backend_cpu_buffer_type()});
params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.weight$", ggml_backend_cpu_buffer_type()});
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
}
).set_env("LLAMA_ARG_CPU_MOE"));
add_opt(common_arg(
{"--n-cpu-moe", "-ncmoe"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU",
[](common_params & params, int value) {
if (value < 0) {
throw std::invalid_argument("invalid value");
}
for (int i = 0; i < value; ++i) {
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
}
}
).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@@ -2647,6 +2663,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_out_freq = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
{"--output-format"}, "{gguf,dat}",
string_format("output format for imatrix file (default: %s)", params.imat_dat > 0 ? "dat" : "gguf"),
[](common_params & params, const std::string & value) {
/**/ if (value == "gguf") { params.imat_dat = -1; }
else if (value == "dat") { params.imat_dat = 1; }
else { throw std::invalid_argument("invalid output format"); }
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
{"--save-frequency"}, "N",
string_format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq),
@@ -2922,11 +2947,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
"(default: deepseek)",
"(default: auto)",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
+9 -1
View File
@@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
std::string arguments = "";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else {
arguments = tool_call.at("arguments");
}
}
return add_tool_call(name, id, arguments);
}
+178 -5
View File
@@ -126,6 +126,8 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool add_bos;
bool add_eos;
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
@@ -143,6 +145,8 @@ struct templates_params {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
json extra_context;
bool add_bos;
bool add_eos;
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -445,6 +449,8 @@ std::string common_chat_format_single(
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
inputs.add_bos = tmpls->add_bos;
inputs.add_eos = tmpls->add_eos;
std::string fmt_past_msg;
if (!past_msg.empty()) {
@@ -469,6 +475,8 @@ std::string common_chat_format_single(
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
inputs.add_bos = tmpls->add_bos;
inputs.add_eos = tmpls->add_eos;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
@@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init(
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;
bool add_eos = false;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
@@ -560,9 +570,13 @@ common_chat_templates_ptr common_chat_templates_init(
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
add_bos = llama_vocab_get_add_bos(vocab);
add_eos = llama_vocab_get_add_eos(vocab);
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
tmpls->add_bos = add_bos;
tmpls->add_eos = add_eos;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
@@ -592,6 +606,8 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -600,8 +616,10 @@ const char * common_chat_format_name(common_chat_format format) {
const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
default:
throw std::runtime_error("Unknown reasoning format");
}
@@ -748,10 +766,10 @@ static std::string apply(
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
}
return result;
@@ -1289,6 +1307,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end);
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_GPT_OSS;
// TODO: support tool calls in GPT-OSS?
return data;
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
// TODO @ngxson : this won't work with --special enabled, we should fix that
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
@@ -1646,7 +1684,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
if (auto res = builder.try_find_regex(open_regex)) {
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
@@ -1668,7 +1706,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
builder.add_content(builder.consume_rest());
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
@@ -1693,7 +1730,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.consume_spaces();
}
}
builder.add_content(builder.consume_rest());
}
}
builder.add_content(builder.consume_rest());
}
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Pass thinking context for Granite template
json additional_context = {
{"thinking", inputs.enable_thinking},
};
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_GRANITE;
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
if (!inputs.tools.is_null()) {
// Granite uses <|tool_call|> followed by JSON list
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
"-args", {
{"type", "object"},
{"properties", {
{"name", {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
})));
});
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
if (data.thinking_forced_open) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
} else {
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
}
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<|tool_call|>"
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
"<|tool_call|>",
};
});
} else {
// Handle thinking tags for non-tool responses
if (data.thinking_forced_open && inputs.enable_thinking) {
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
};
}
}
return data;
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags using regex
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
if (auto res = builder.try_find_regex(response_regex)) {
// Extract the content between the tags (capture group 1)
auto content = builder.str(res->groups[1]);
builder.add_content(content);
builder.move_to(res->groups[0].end);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.add_tool_calls(tool_calls_data.json)) {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content(builder.consume_rest());
@@ -1733,6 +1887,8 @@ static common_chat_params common_chat_templates_apply_jinja(
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = inputs.add_bos;
params.add_eos = inputs.add_eos;
params.extra_context = json::object();
for (auto el : inputs.chat_template_kwargs) {
@@ -1769,11 +1925,21 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_command_r7b(tmpl, params);
}
// Granite (IBM) - detects thinking / tools support
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
return common_chat_params_init_granite(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// GPT-OSS
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_gpt_oss(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1824,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy(
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
@@ -1925,6 +2092,12 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
+4
View File
@@ -109,6 +109,8 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -127,6 +129,8 @@ struct common_chat_templates_inputs {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;
bool add_bos = false;
bool add_eos = false;
};
struct common_chat_params {
+4 -1
View File
@@ -236,8 +236,10 @@ struct common_params_diffusion {
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
struct common_params {
@@ -394,7 +396,7 @@ struct common_params {
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
@@ -439,6 +441,7 @@ struct common_params {
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
int32_t i_chunk = 0; // start processing from this chunk
int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat)
bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity
+382 -5
View File
@@ -678,6 +678,9 @@ class TextModel(ModelBase):
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
@@ -702,6 +705,9 @@ class TextModel(ModelBase):
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
res = "kimi-k2"
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -849,9 +855,9 @@ class TextModel(ModelBase):
if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
res = "exaone4"
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-8B
res = "qwen2"
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if res is None:
logger.warning("\n")
@@ -3322,7 +3328,13 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list):
self.hparams_vision['image_size'] = self.hparams_vision['image_size'][0]
if isinstance(self.hparams_vision['patch_size'], list):
self.hparams_vision['patch_size'] = self.hparams_vision['patch_size'][0]
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
@@ -3346,14 +3358,30 @@ class InternVisionModel(MmprojModel):
return gguf.GGMLQuantizationType.F32
return False
def _mapping_interns1_name(self, name):
names_map = {
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
}
if name in names_map:
name = names_map[name]
return name
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("vision_model") or name.startswith("mlp"):
vision_prefix = ['vision_model', 'mlp', 'model.vision_tower', 'model.multi_modal_projector']
# deal with intern-s1 special case
name = self._mapping_interns1_name(name)
if any([name.startswith(prefix) for prefix in vision_prefix]):
# process visual tensors
# correct name
if name.startswith("vision_model"):
name = "vision_tower." + name
if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
if (".ls" in name or ".lambda_" in name or "position_embedding" in name) and not name.endswith(".weight"):
name += ".weight"
# split QKV tensors if needed
if ".qkv." in name:
@@ -3439,6 +3467,10 @@ class Qwen2MoeModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip visual tensors
return []
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
@@ -3492,6 +3524,85 @@ class Qwen3Model(Qwen2Model):
class Qwen3MoeModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3MOE
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams = ModelBase.load_hparams(self.dir_model)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
def set_vocab(self):
# deal with intern-s1
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
self._set_vocab_interns1()
return
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
additional_special_tokens = []
if special_tokens_map_file.is_file():
with open(special_tokens_map_file, encoding = 'utf-8') as f:
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
if tokenizer_cfg_file.is_file():
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
for token in additional_special_tokens:
if token in token2ids_map:
special_vocab._set_special_token(token, token2ids_map[token])
special_vocab._set_special_token('eos', 151645)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
@@ -6059,6 +6170,7 @@ class DeepseekModel(TextModel):
@ModelBase.register("DeepseekV2ForCausalLM")
@ModelBase.register("DeepseekV3ForCausalLM")
@ModelBase.register("KimiVLForConditionalGeneration")
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -6161,6 +6273,13 @@ class DeepseekV2Model(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
return []
if name.startswith("language_model."):
name = name.replace("language_model.", "")
# rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -6685,6 +6804,139 @@ class Glm4Model(TextModel):
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Glm4MoeForCausalLM")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
# Patch broken chat template
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
special_vocab.chat_template = special_vocab.chat_template.replace(
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = (
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
)
self.gguf_writer.add_rope_dimension_count(
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
)
# MoE parameters - Use only routed expert count (shared experts handled separately)
if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
self.gguf_writer.add_expert_count(n_routed_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
# Expert gating function (sigmoid for GLM4_MOE)
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
# Routed scaling factor
if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
# Normalise topk probabilities
if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."): # ignore visual part
return []
elif name.startswith("model.language_model."):
name = name.replace("language_model.", "") # for multimodal variants
# Handle main token embedding (but not layer-specific NextN embeddings)
if name == "model.embed_tokens.weight" and ".layers." not in name:
return [(self.map_tensor_name("token_embd.weight"), data_torch)]
# Handle routed experts
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
new_name = self.map_tensor_name(name)
return [(new_name, data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(TextModel):
model_arch = gguf.MODEL_ARCH.CHATGLM
@@ -7803,6 +8055,130 @@ class SmolLM3Model(LlamaModel):
self.gguf_writer.add_chat_template(chat_template)
@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS
def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
# swap nibbles
t_lo = tensor & 0x0F
t_hi = tensor & 0xF0
t_swapped = (t_lo << 4) | (t_hi >> 4)
tensor = t_swapped
# transform aaaa...bbbb... to abababab...
blk_a, blk_b = tensor.chunk(2, dim=-1)
# get a_
blk_a0 = (blk_a & 0xF0).view(-1, 1)
blk_a1 = (blk_a << 4).view(-1, 1)
blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
# get _b
blk_b0 = (blk_b >> 4).view(-1, 1)
blk_b1 = (blk_b & 0x0F).view(-1, 1)
blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
# swap once more
out = blk_a | blk_b
out_h = out & 0xF0
out_l = out & 0x0F
out = (out_h >> 4) | (out_l << 4)
return out
def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
assert blocks.dtype == torch.uint8
assert scales.dtype == torch.uint8
scales = scales.unsqueeze(-1)
assert len(blocks.shape) == 4
assert len(scales.shape) == 4
blocks = self.transform_nibble_layout(blocks)
new_data = torch.concat((scales, blocks), dim=-1)
new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
# flatten last dim
new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
new_data = new_data.numpy()
self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
blocks0: Tensor = torch.zeros(1)
blocks1: Tensor = torch.zeros(1)
# we assume that tensors are loaded in the correct order
for name, data_torch in self.get_tensors():
if "mlp.experts.down_proj_blocks" in name:
blocks0 = data_torch
elif "mlp.experts.down_proj_scales" in name:
new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
self.repack_mxfp4(new_name, blocks0, data_torch)
elif "mlp.experts.gate_up_proj_blocks" in name:
blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
elif "mlp.experts.gate_up_proj_scales" in name:
scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
self.repack_mxfp4(new_name_gate, blocks0, scales0)
self.repack_mxfp4(new_name_up, blocks1, scales1)
return []
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if "sinks" in name:
name += ".weight"
# correct naming for down_proj
if "down_proj" in name:
if name.endswith("_bias"):
name = name.replace("down_proj_bias", "down_proj.bias")
elif "_blocks" not in name and "_scales" not in name:
logger.warning(f"{name} is not in MXFP4, performance may be degraded")
name = name.replace("down_proj", "down_proj.weight")
data_torch = data_torch.transpose(-1, -2)
else:
# otherwise, it should already be repacked to ggml MXFP4 format
return []
# split the gate_up into gate and up
if "gate_up_proj" in name:
if name.endswith("_bias"):
name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
return [
(self.map_tensor_name(name_gate), gate_proj_bias),
(self.map_tensor_name(name_up), up_proj_bias)
]
elif "_blocks" not in name and "_scales" not in name:
logger.warning(f"{name} is not in MXFP4, performance may be degraded")
name_up = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
data_torch = data_torch.transpose(-1, -2)
gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :]
return [
(self.map_tensor_name(name_gate), gate_proj_weight),
(self.map_tensor_name(name_up), up_proj_weight)
]
else:
# otherwise, it should already be repacked to ggml MXFP4 format
return []
return [(self.map_tensor_name(name), data_torch)]
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])
rope_scaling = self.hparams.get("rope_scaling") or {}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
@@ -7942,6 +8318,7 @@ class LazyTorchTensor(gguf.LazyBase):
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.uint8: np.uint8,
}
# used for safetensors slices
+18 -6
View File
@@ -59,6 +59,10 @@ parser.add_argument(
"--full", action="store_true",
help="download full list of models - make sure you have access to all of them",
)
parser.add_argument(
"--check-missing", action="store_true",
help="only check for missing pre-tokenizer hashes",
)
parser.add_argument(
"hf_token",
help="optional HF token",
@@ -70,6 +74,10 @@ hf_token = args.hf_token if args.hf_token is not None else hf_token
if hf_token is None:
logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")
if args.check_missing and args.full:
logger.warning("Downloading full list of models requested, ignoring --check-missing!")
args.check_missing = False
# TODO: this string has to exercise as much pre-tokenizer functionality as possible
# will be updated with time - contributions welcome
CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
@@ -130,6 +138,7 @@ models = [
{"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@@ -138,6 +147,7 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
@@ -147,6 +157,7 @@ pre_computed_hashes = [
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
]
@@ -221,12 +232,13 @@ if not args.full:
all_models = models.copy()
models = [model for model in all_models if model["name"] not in existing_models]
logging.info(f"Downloading {len(models)} models...")
for model in models:
try:
download_model(model)
except Exception as e:
logger.error(f"Failed to download model {model['name']}. Error: {e}")
if not args.check_missing:
logging.info(f"Downloading {len(models)} models...")
for model in models:
try:
download_model(model)
except Exception as e:
logger.error(f"Failed to download model {model['name']}. Error: {e}")
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
+4 -2
View File
@@ -39,8 +39,9 @@ if (WIN32)
set(CMAKE_SHARED_MODULE_PREFIX "")
endif()
option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
set(GGML_BACKEND_DIR "" CACHE PATH "ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL")
#
# option list
@@ -175,6 +176,7 @@ option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM"
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
+41 -39
View File
@@ -125,54 +125,56 @@ if(NOT TARGET ggml::ggml)
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
set(_ggml_all_targets "")
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
if (NOT GGML_BACKEND_DL)
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
endif()
endif()
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
endif()
endif()
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
endif()
list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
set_target_properties(ggml::ggml
+37 -1
View File
@@ -304,6 +304,16 @@
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_TERNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_COUNT = 39,
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_COUNT = 40,
};
// precision
@@ -430,6 +441,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
};
// available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {
GGML_OP_DUP,
GGML_OP_ADD,
GGML_OP_ADD_ID,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_SWIGLU_OAI,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
@@ -831,6 +845,13 @@ extern "C" {
struct ggml_tensor * b,
enum ggml_type type);
// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
GGML_API struct ggml_tensor * ggml_add_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids);
GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1198,6 +1219,13 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_swiglu_oai(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float alpha,
float limit);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
float scale,
float max_bias);
GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -2052,6 +2084,10 @@ extern "C" {
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
const struct ggml_tensor * a);
GGML_API void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
+12 -1
View File
@@ -214,6 +214,13 @@ add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
if (GGML_BACKEND_DIR)
if (NOT GGML_BACKEND_DL)
message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
endif()
target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR="${GGML_BACKEND_DIR}")
endif()
target_link_libraries(ggml PUBLIC ggml-base)
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
@@ -227,7 +234,11 @@ function(ggml_add_backend_library backend)
set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
add_dependencies(ggml ${backend})
install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
if (GGML_BACKEND_DIR)
install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR})
else()
install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
else()
add_library(${backend} ${ARGN})
target_link_libraries(ggml PUBLIC ${backend})
+1
View File
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
+3
View File
@@ -498,6 +498,9 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
std::vector<fs::path> search_paths;
if (user_search_path == nullptr) {
#ifdef GGML_BACKEND_DIR
search_paths.push_back(fs::u8path(GGML_BACKEND_DIR));
#endif
// default search paths: executable directory, current directory
search_paths.push_back(get_executable_path());
search_paths.push_back(fs::current_path());
+7 -2
View File
@@ -1071,6 +1071,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
}
}
}
// if the node is still unassigned, assign it to the first backend that supports it
for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) {
ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id);
}
GGML_ASSERT(*cur_backend_id != -1);
}
// pass 5: split graph, find tensors that need to be copied
@@ -1098,7 +1103,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
const int node_backend_id = tensor_backend_id(node);
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
@@ -1156,7 +1161,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
size_t src_id = hash_id(src);
const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
assert(src_backend_id != -1); // all inputs should be assigned by now
GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now
if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
+4 -4
View File
@@ -281,10 +281,10 @@ ggml_backend_t ggml_backend_blas_init(void) {
ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_blas_guid(),
/* .interface = */ blas_backend_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
/* .context = */ ctx,
/* .guid = */ ggml_backend_blas_guid(),
/* .iface = */ blas_backend_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
/* .context = */ ctx,
};
#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+14
View File
@@ -31,6 +31,13 @@ string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}")
set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")
option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF)
if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P"))
message(FATAL_ERROR
"CANN Graph (ACL graph mode) is not supported on 310P devices. "
"Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.")
endif()
if (CANN_INSTALL_DIR)
# Only Support Linux.
@@ -68,6 +75,13 @@ if (CANN_INSTALL_DIR)
target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
if (USE_ACL_GRAPH)
target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH)
message(STATUS "CANN: USE_ACL_GRAPH is enabled.")
else()
message(STATUS "CANN: USE_ACL_GRAPH is disabled.")
endif()
message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
else()
+36
View File
@@ -337,6 +337,29 @@ private:
int32_t device_;
};
#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cann_graph {
~ggml_cann_graph() {
if (graph != nullptr) {
aclmdlRIDestroy(graph);
}
}
aclmdlRI graph = nullptr;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
};
#endif // USE_ACL_GRAPH
/**
* @brief Context for managing CANN backend operations.
*/
@@ -345,8 +368,13 @@ struct ggml_backend_cann_context {
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
#endif
cann_task_queue task_queue;
bool async_mode;
bool support_set_rows;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
@@ -362,6 +390,14 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
if (!support_set_rows) {
GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
"Falling back to eager mode.\n", __func__);
}
}
/**
+189 -22
View File
@@ -2075,6 +2075,160 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
#ifdef USE_ACL_GRAPH
/**
* @brief Populate the internal CANN graph node properties from the ggml computation graph.
*
* This function copies all node attributes (operation type, dimensions, strides, input sources,
* and operation parameters) into the cached CANN graph structure for later reuse or comparison.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computational graph.
*/
static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
ggml_tensor * node = cgraph->nodes[node_idx];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
}
for (int src = 0; src < GGML_MAX_SRC; src++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
node->src[src] ? node->src[src]->data : nullptr;
}
memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
}
/**
* @brief Check if a ggml tensor node matches a previously captured CANN graph node.
*
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
* to determine whether the current node matches a previously recorded version.
*
* @param node The current ggml tensor node.
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != graph_node_properties->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
}
if (node->nb[i] != graph_node_properties->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
}
}
if (node->op == GGML_OP_SCALE &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
return true;
}
/**
* @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
*
* This checks whether the number or properties of ggml graph nodes have changed
* compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The current ggml computation graph.
* @return true if an update is required; false otherwise.
*/
static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
// The number of nodes is different, so the graph needs to be reconstructed.
if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
return true;
}
// The number of nodes is the same; iterate over each node to check whether they match.
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = ggml_graph_node_has_matching_properties(
cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
if(!has_matching_properties) {
return true;
}
}
return false;
}
#endif // USE_ACL_GRAPH
/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
* If CANN graph execution is enabled and graph capture is required, this function begins
* graph capture, runs the graph, ends capture, and stores the captured graph.
*
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
* @param use_cann_graph Whether to use CANN graph execution.
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
*/
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
bool & use_cann_graph, bool & cann_graph_update_required) {
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) {
if (cann_ctx->cann_graph->graph != nullptr) {
ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
cann_ctx->cann_graph->graph = nullptr;
}
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
if (!use_cann_graph || cann_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
}
}
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
}
if (use_cann_graph) {
// Execute graph
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
}
#endif // USE_ACL_GRAPH
}
/**
* @brief Computes a computational graph using a CANN backend.
*
@@ -2091,27 +2245,38 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_backend_t backend, ggml_cgraph* cgraph) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
//release temp buffer create by set tensor.
release_nz_workspace();
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor* node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
continue;
}
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
// check environment LLAMA_SET_ROWS
if (!cann_ctx->support_set_rows) {
use_cann_graph = false;
}
if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
cann_graph_update_required = true;
}
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
set_ggml_graph_node_properties(cann_ctx, cgraph);
}
#else
bool use_cann_graph = false;
bool cann_graph_update_required = false;
#endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph(
cann_ctx,
cgraph,
use_cann_graph,
cann_graph_update_required
);
return GGML_STATUS_SUCCESS;
}
@@ -2226,12 +2391,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// only support F32 and F16.
return false;
}
if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
// unsupport dst is not contiguous.
return false;
}
return true;
} break;
case GGML_OP_CONT: {
@@ -2340,6 +2499,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -2354,6 +2517,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
return false;
}
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[4]) {
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
+17
View File
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
#define QI4_1 (QK4_1 / (4 * QR4_1))
#define QR4_1 2
#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
#define QR_MXFP4 2
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
@@ -184,6 +187,13 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK_MXFP4 32
typedef struct {
uint8_t e; // E8M0
uint8_t qs[QK_MXFP4/2];
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
#define QK5_0 32
typedef struct {
ggml_half d; // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()
// TODO: fix name to kvalues_iq4_nl
GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
+6
View File
@@ -13,6 +13,7 @@
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -68,6 +69,7 @@
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -90,6 +92,7 @@
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -120,6 +123,7 @@
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -149,6 +153,7 @@
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -179,6 +184,7 @@
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+61
View File
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP4;
int ib = 0;
float sumf = 0;
#if defined __ARM_NEON
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
uint8x16x2_t q4bits;
int8x16x4_t q4b;
int8x16x4_t q8b;
int32x4_t prod_1;
int32x4_t prod_2;
for (; ib + 1 < nb; ib += 2) {
q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
q8b.val[0] = vld1q_s8(y[ib + 0].qs);
q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
q8b.val[2] = vld1q_s8(y[ib + 1].qs);
q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
sumf +=
GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
}
#endif
for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+96 -8
View File
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
}
#if defined(__AVX2__) || defined(__AVX512F__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
const __m256i sy = _mm256_sign_epi8(y, x);
return _mm256_maddubs_epi16(ax, sy);
}
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP4;
int ib = 0;
float sumf = 0;
#if defined __AVX2__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
_mm256_cvtepi32_ps(p_1), accum1);
accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
_mm256_cvtepi32_ps(p_2), accum2);
}
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
}
sumf = hsum_float_8(accum);
#endif
for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}
#if defined(__AVX2__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
const __m256i sy = _mm256_sign_epi8(y, x);
return _mm256_maddubs_epi16(ax, sy);
}
#endif
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
+14 -1
View File
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
},
[GGML_TYPE_MXFP4] = {
.from_float = quantize_row_mxfp4,
.vec_dot = ggml_vec_dot_mxfp4_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_add(params, tensor);
} break;
case GGML_OP_ADD_ID:
{
ggml_compute_forward_add_id(params, tensor);
} break;
case GGML_OP_ADD1:
{
ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
@@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DUP:
case GGML_OP_CONT:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
{
@@ -2172,6 +2183,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
@@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
}
} break;
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
{
if (ggml_is_quantized(node->src[0]->type)) {
+21 -24
View File
@@ -35,7 +35,7 @@
// ggml-backend interface
std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts;
@@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
#endif
bufts.push_back(NULL);
return bufts;
}();
@@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
return ggml_backend_cpu_get_extra_buffers_type().data();
static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
bufts.push_back(nullptr);
return bufts;
}();
return extra_bufts.data();
GGML_UNUSED(device);
}
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra == buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra == buft) {
return true;
}
}
@@ -210,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ctx->abort_callback_data = NULL;
ggml_backend_t cpu_backend = new ggml_backend {
/* .guid = */ ggml_backend_cpu_guid(),
/* .interface = */ ggml_backend_cpu_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ ctx,
/* .guid = */ ggml_backend_cpu_guid(),
/* .iface = */ ggml_backend_cpu_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ ctx,
};
if (cpu_backend == NULL) {
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
return true;
}
// extra_buffer_op?
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra) {
auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
if (buf_extra && buf_extra->supports_op(dev, op)) {
return true;
}
}
}
// the other case need host buffer.
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
return false;
// check extra buffer types
// note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
for (int i = 0; i < 4; i++) {
if (op->src[i] && op->src[i]->buffer &&
ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
return buf_extra->supports_op(dev, op);
}
}
+207 -9
View File
@@ -8,6 +8,7 @@
#include "vec.h"
#include <float.h>
#include <algorithm>
// ggml_compute_forward_dup
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
}
}
// ggml_compute_forward_add_id
static void ggml_compute_forward_add_id_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(src0);
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
// src1 indices
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
GGML_ASSERT(i11 >= 0 && i11 < ne11);
ggml_vec_add_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i11*nb11));
}
}
void ggml_compute_forward_add_id(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_add_id_f32(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
}
}
}
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
}
}
// ggml_compute_forward_swiglu_oai
static void ggml_compute_forward_swiglu_oai_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
for (int k = 0; k < nc; k++) {
const float x = std::min(src0_p[k], limit);
const float y = std::clamp(src1_p[k], -limit, limit);
const float out_glu = x / (1.f + expf(alpha * (-x)));
dst_p[k] = out_glu * (y + 1.f);
}
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = dst_p[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_swiglu_oai(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_swiglu_oai_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
// sinks
const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
float max = -INFINITY;
ggml_vec_max_f32(ne00, &max, wp);
// if we have sinks, make a correction as if they were included in the softmax
if (sk) {
max = MAX(max, sk[i02]);
}
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
assert(sum > 0.0);
if (sk) {
sum += (ggml_float) expf(sk[i02] - max);
}
sum = 1.0/sum;
ggml_vec_scale_f32(ne00, dp, sum);
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
const ggml_tensor * q,
const ggml_tensor * k,
const ggml_tensor * v,
const ggml_tensor * mask,
ggml_tensor * dst) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
}
// sinks
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
vs = expf(s - M);
}
S = S*ms + vs;
}
// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
const ggml_tensor * q,
const ggml_tensor * k,
const ggml_tensor * v,
const ggml_tensor * mask,
ggml_tensor * dst) {
switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
// uses F32 accumulators
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
ggml_compute_forward_flash_attn_ext_f16(params, dst);
} break;
default:
{
@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_GLU_OP_SWIGLU_OAI:
{
ggml_compute_forward_swiglu_oai(params, dst);
} break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
+2 -7
View File
@@ -29,6 +29,7 @@ extern "C" {
void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
const bool masked,
+35
View File
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
quantize_row_q8_1_ref(x, y, k);
}
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
quantize_row_mxfp4_ref(x, y, k);
}
//
// 2-6 bit quantization in super-blocks
//
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP4;
int ib = 0;
float sumf = 0;
for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+8
View File
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+2 -2
View File
@@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
} // namespace ggml::cpu
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
@@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct
}
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
+1 -1
View File
@@ -33,6 +33,6 @@ class extra_buffer_type {
} // namespace ggml::cpu
// implemented in ggml-cpu.cpp.
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types();
#endif
+19 -4
View File
@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
int i = 0;
#if defined(__AVX2__)
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(x + i);
__m256 vy = _mm256_loadu_ps(y + i);
__m256 vz = _mm256_add_ps(vx, vy);
_mm256_storeu_ps(z + i, vz);
}
#endif
for (; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
for (int i = 0; i < n; ++i) {
z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float v = GGML_CPU_FP16_TO_FP32(x[i]);
float w = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
float xi = GGML_CPU_FP16_TO_FP32(x[i]);
float gi = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
}
}
+58
View File
@@ -0,0 +1,58 @@
#include "add-id.cuh"
static __global__ void add_id_kernel(
const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne0, int64_t ne1,
size_t nb01, size_t nb02,
size_t nb11,
size_t nb21
) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
const float * src1_row = (const float *)((char *)src1 + i11*nb11);
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(int32_t));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const int32_t * src2_d = (const int32_t *)src2->data;
float * dst_d = (float *)dst->data;
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src0_d, src1_d, src2_d, dst_d,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+36 -2
View File
@@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "ggml-impl.h"
#include "ggml-cuda.h"
#include <cstdint>
@@ -232,9 +233,13 @@ typedef float2 dfloat2;
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -302,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static bool ampere_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
@@ -549,6 +558,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP)
}
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
return (float) e;
#else
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
float result;
memcpy(&result, &bits, sizeof(float));
return result;
#endif // CUDART_VERSION >= 12050
}
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __device__ __forceinline__ float get_alibi_slope(
@@ -607,6 +634,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qi = QI8_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
static constexpr int qk = QK_MXFP4;
static constexpr int qr = QR_MXFP4;
static constexpr int qi = QI_MXFP4;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
+28
View File
@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
}
}
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
}
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
+4 -1
View File
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -736,7 +737,8 @@ void launch_fattn(
GGML_ASSERT(V || is_mla);
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -940,6 +942,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *) sinks->data) : nullptr,
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+79 -24
View File
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float * const __restrict__ KQ_max,
float * const __restrict__ KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
typedef fattn_mma_f16_config<DKQ, DV> c;
#ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
@@ -800,7 +801,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
// so it's being done unconditionally for every thread.
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}
// Combine VKQ accumulator values if np > 1.
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1196,7 +1243,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1206,6 +1253,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -1222,7 +1270,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1267,20 +1315,24 @@ static __global__ void flash_attn_ext_f16(
// kb0 == k start index when in the output tile.
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const int head0 = zt * ncols2;
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1293,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1314,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
}
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const int head0 = zt * ncols2;
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1337,10 +1392,10 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
@@ -1352,7 +1407,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
}
template <int DKQ, int DV, int ncols1, int ncols2>
+32 -5
View File
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -48,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -241,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
kqmax[j0/nwarps] = kqmax_new_j;
const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
@@ -272,7 +299,7 @@ static __global__ void flash_attn_tile_ext_f16(
}
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+34 -5
View File
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
@@ -59,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -251,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
+39 -3
View File
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols];
half kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -HALF_MAX_HALF;
kqsum[j] = 0.0f;
}
half kqsum[ncols] = {0.0f};
__shared__ half kqmax_shared[ncols][WARP_SIZE];
__shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
}
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const half val = hexp(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= __half2half2(KQ_max_scale);
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+38 -2
View File
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32(
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
}
float kqmax[ncols];
float kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -FLT_MAX/2.0f;
kqsum[j] = 0.0f;
}
float kqsum[ncols] = {0.0f};
__shared__ float kqmax_shared[ncols][WARP_SIZE];
__shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
}
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const float val = expf(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= KQ_max_scale;
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);
+55 -6
View File
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -81,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * sinksf = (const float *) sinks;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
@@ -380,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
}
// Apply attention sinks
if (sinksf && blockIdx.y == 0) {
const float sinkf = sinksf[head];
const half sinkh = __float2half(sinkf);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
KQ_max_f[j0/nwarps] = kqmax_new;
KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= scale_h2;
}
} else {
half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
half kqmax_new = fmaxf(kqmax_old, sinkh);
KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
const half val = hexp(sinkh - kqmax_new);
KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
}
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
@@ -423,7 +472,7 @@ static __global__ void flash_attn_ext_f16(
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+9 -8
View File
@@ -269,11 +269,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -315,8 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
(Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
@@ -328,7 +329,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (fp16_mma_available(cc) && !new_mma_available(cc)) {
if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
+62 -20
View File
@@ -4,6 +4,7 @@
#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/argsort.cuh"
@@ -21,8 +22,9 @@
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmf.cuh"
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvf.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
@@ -1852,6 +1854,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
// Handle src0
src0_ptr = (const cuda_t *) src0->data;
@@ -1870,6 +1875,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
s11 = ne10;
s12 = ne11*s11;
s13 = ne12*s12;
is_src1_cont_2 = true;
}
// Setup destination buffer
@@ -1918,15 +1925,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
const int64_t smb = ne12 == 1 ? s13 : s12;
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
src1_ptr, cu_data_type_b, s11, s12, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
src1_ptr, cu_data_type_b, s11, smb, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1998,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_f = !ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2018,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
@@ -2038,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
//TODO update for generic tensor parallelism
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
if (!split && use_mul_mat_vec) {
if (!split && use_mul_mat_vec_f) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_f) {
ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_vec_q) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_q) {
@@ -2055,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_f) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
@@ -2084,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
} else {
ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
}
return;
}
@@ -2250,6 +2269,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ADD1: // TODO: more efficient implementation
ggml_cuda_op_add(ctx, dst);
break;
case GGML_OP_ADD_ID:
ggml_cuda_op_add_id(ctx, dst);
break;
case GGML_OP_SUB:
ggml_cuda_op_sub(ctx, dst);
break;
@@ -2324,6 +2346,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU_OAI:
ggml_cuda_op_swiglu_oai(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
@@ -2598,6 +2623,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2620,7 +2648,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -3218,6 +3252,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
@@ -3268,6 +3303,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -3414,6 +3450,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -3488,12 +3525,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc)) {
if (!turing_mma_available(cc)) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
// TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
&& op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
@@ -3757,10 +3799,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
}
ggml_backend_t cuda_backend = new ggml_backend {
/* .guid = */ ggml_backend_cuda_guid(),
/* .interface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .context = */ ctx,
/* .guid = */ ggml_backend_cuda_guid(),
/* .iface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .context = */ ctx,
};
return cuda_backend;
+46 -35
View File
@@ -1,65 +1,76 @@
#include "im2col.cuh"
#define MAX_GRIDDIM_Z 65535
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int64_t batch_offset,
int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
const float * x, T * dst,
int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
if (i >= IC_KH_KW) {
return;
}
const int64_t ksize = OW * KH;
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;
const int64_t iic = i / (KH_KW);
const int64_t rem = i - iic * KH_KW;
const int64_t ikh = rem / KW;
const int64_t ikw = rem - ikh * KW;
const int64_t oh = blockIdx.y;
const int64_t batch = blockIdx.z / IC;
const int64_t ic = blockIdx.z % IC;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
GGML_UNUSED(IC);
GGML_UNUSED(KH);
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
template <typename T>
static void im2col_cuda(const float * x, T* dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
const int64_t IC_KH_KW = IC * KH * KW;
const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
const int64_t N_OH = N * OH;
const int64_t KH_KW = KW*KH;
dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
s0, s1, p0, p1, d0, d1);
}
static void im2col_cuda_f16(const float * x, half * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
static void im2col_cuda_f32(const float * x, float * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -91,13 +102,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[is_2D ? 3 : 2];
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
} else {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}
+88 -22
View File
@@ -23,13 +23,13 @@
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
#endif // defined(TURING_MMA_AVAILABLE)
return ret;
}
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
}
};
template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(NEW_MMA_AVAILABLE)
#if defined(TURING_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
+431
View File
@@ -0,0 +1,431 @@
#include "ggml.h"
#include "common.cuh"
#include "mma.cuh"
#include "mmf.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int channel_dst = blockIdx.y;
const int channel_x = channel_dst / channel_ratio;
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
tile_C C[ntA][ntB];
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) data_mmv;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template <typename T, int cols_per_block>
static void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
GGML_ASSERT(!ids && "mul_mat_id not implemented");
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 2: {
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 3: {
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 4: {
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 5: {
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 6: {
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 7: {
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 8: {
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false;
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
if (ne11 > 16) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return turing_mma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc);
default:
return false;
}
}
+5
View File
@@ -0,0 +1,5 @@
#include "common.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
+5 -1
View File
@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
@@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -306,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (new_mma_available(cc)) {
if (turing_mma_available(cc)) {
return true;
}
+206 -128
View File
File diff suppressed because it is too large Load Diff
@@ -1,9 +1,9 @@
#include "ggml.h"
#include "common.cuh"
#include "mmv.cuh"
#include "mmvf.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec(
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
@@ -37,7 +37,7 @@ static __global__ void mul_mat_vec(
float sumf[ncols_dst] = {0.0f};
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same_v<T, float>) {
const float2 * x2 = (const float2 *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
@@ -50,10 +50,10 @@ static __global__ void mul_mat_vec(
sumf[j] += tmpx.y*tmpy.y;
}
}
} else if constexpr (std::is_same<T, half>::value) {
} else if constexpr (std::is_same_v<T, half>) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
if (std::is_same_v<type_acc, float>) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
@@ -86,7 +86,7 @@ static __global__ void mul_mat_vec(
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
@@ -98,7 +98,7 @@ static __global__ void mul_mat_vec(
}
}
} else {
static_assert(std::is_same<T, void>::value, "unsupported type");
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
@@ -126,7 +126,7 @@ static __global__ void mul_mat_vec(
}
template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_cuda(
static void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda(
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
int device;
int warp_size;
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
@@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda(
}
}
const int smem = warp_size*sizeof(float);
const int nbytes_shared = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda(
}
template <typename T, typename type_acc>
static void mul_mat_vec_cuda_switch_ncols_dst(
static void mul_mat_vec_f_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
cudaStream_t stream) {
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_cuda<T, type_acc, 1>
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_cuda<T, type_acc, 2>
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_cuda<T, type_acc, 3>
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_cuda<T, type_acc, 4>
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_cuda<T, type_acc, 5>
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_cuda<T, type_acc, 6>
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_cuda<T, type_acc, 7>
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_cuda<T, type_acc, 8>
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
@@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
}
template<typename T>
static void mul_mat_vec_cuda(
static void mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
@@ -292,22 +290,22 @@ static void mul_mat_vec_cuda(
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same<T, half>::value) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_cuda_switch_ncols_dst<T, half>
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
mul_mat_vec_cuda_switch_ncols_dst<T, float>
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
@@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
}
}
void ggml_cuda_op_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec(
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
@@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec(
GGML_UNUSED(src1_padded_row_size);
}
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return ne11 <= 8;
if (ampere_mma_available(cc)) {
return ne11 <= 3;
}
if (cc >= GGML_CUDA_CC_TURING) {
return ne11 <= 4;
@@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
case GGML_TYPE_F16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
@@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
case GGML_TYPE_BF16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
@@ -1,11 +1,11 @@
#include "common.cuh"
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+9
View File
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+16 -10
View File
@@ -45,7 +45,7 @@ struct soft_max_params {
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const soft_max_params p) {
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int tid = threadIdx.x;
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
// shared memory buffer to cache values between iterations:
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = -INFINITY;
float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
tmp = warp_reduce_sum(tmp);
}
if (sinks) {
tmp += expf(sinks[i02] - max_val);
}
const float inv_sum = 1.0f / tmp;
#pragma unroll
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
}
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, p);
(x, mask, sinks, dst, p);
return true;
}
return false;
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
params.m1 = m1;
if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}
+153 -71
View File
@@ -1,87 +1,117 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB
#include "ssm-scan.cuh"
template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t d_inner, const int64_t L) {
const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
const size_t L = L_template == 0 ? L_param : L_template;
const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int bidx = blockIdx.x; // split along B (sequences)
const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
extern __shared__ float smem[];
const int stride_sA = N + 1;
const int stride_ss0 = N + 1;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_x = src1_nb2 / sizeof(float);
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = d_inner;
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = d_inner;
// can N not be 16? for example 32?
if (N == 16) {
float regA[N];
float regs0[N];
__shared__ float smemB[N];
__shared__ float smemC[N];
#ifdef USE_CUB
using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
union CubTempStorage {
typename BlockLoad::TempStorage load_temp;
typename BlockStore::TempStorage store_temp;
};
__shared__ CubTempStorage cub_temp_storage;
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
for (size_t n = 0; n < N; ++n)
{
regA[n] = A_block[threadIdx.x * stride_A + n];
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
}
#endif
__syncthreads();
for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[tid * stride_s + j] = state;
} else {
smem_s0[tid * stride_ss0 + j] = state;
}
for (size_t i = 0; i < L; i++)
{
if (threadIdx.x < N)
{
smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
}
__syncthreads();
y_block[i * stride_y + tid] = sumf;
float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
if (dt_soft_plus <= 20.0f)
{
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (size_t n = 0; n < N; n++)
{
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
sumf += state * smemC[n];
regs0[n] = state;
}
y_block[i * stride_y + threadIdx.x] = sumf;
}
#ifdef USE_CUB
BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
const int stride_s = stride_s0;
#pragma unroll
for (size_t n = 0; n < N; ++n)
{
s_block[threadIdx.x * stride_s + n] = regs0[n];
}
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
// assumes as many threads as d_state
template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
const int threads = 128;
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
@@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
GGML_ABORT("doesn't support d_state!=(128 or 256).");
}
} else {
const int threads = 128;
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
@@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
switch (n_tok)
{
case 1:
ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 2:
ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 3:
ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 4:
ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 5:
ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 6:
ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 7:
ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 8:
ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
default:
ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
}
} else {
GGML_ABORT("doesn't support d_state!=16.");
}
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_MXFP4);
+75
View File
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
// swiglu_oai
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
// perform base op and multiply with gate (either offset in same tensor or a separate one)
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
float xi = x[j0];
float gi = g[j1];
xi = fminf(xi, limit);
gi = fmaxf(fminf(gi, limit), -limit);
float out_glu = xi / (1.0f + expf(-xi * alpha));
out_glu = out_glu * (1.0f + gi);
dst[i] = out_glu;
}
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
void * src0_d = src0->data;
void * src1_d = src1 ? src1->data : src0->data;
const int64_t src0_o = src0->nb[1];
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
void * dst_d = dst->data;
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
GGML_ASSERT(src1->ne[0] == nc);
GGML_ASSERT(src0->type == src1->type);
}
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
float * src0_p = (float *) src0_d;
float * src1_p = (float *) src1_d;
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
+2
View File
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+52 -16
View File
@@ -1,8 +1,20 @@
#pragma once
#include "common.cuh"
#include <cstdint>
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
const uint8_t * x8 = (const uint8_t *) x;
int x32 = x8[4*i32 + 0] << 0;
x32 |= x8[4*i32 + 1] << 8;
x32 |= x8[4*i32 + 2] << 16;
x32 |= x8[4*i32 + 3] << 24;
return x32;
}
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
return d8_1*sumf;
}
#define VDR_MXFP4_Q8_1_MMVQ 2
#define VDR_MXFP4_Q8_1_MMQ 4
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
const int * q8 = (const int *) bq8_1->qs + iqs;
int sumi = 0;
#pragma unroll
for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
}
const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
return d * sumi;
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4);
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
const int2 v = get_int_from_table_16(aux_q4);
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
+4
View File
@@ -6,6 +6,10 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION >= 12050
#include <cuda_fp8.h>
#endif // CUDART_VERSION >= 12050
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+1
View File
@@ -200,6 +200,7 @@
#endif
typedef hip_bfloat16 nv_bfloat16;
typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+2 -1
View File
@@ -137,4 +137,5 @@
#define cudaStreamEndCapture musaStreamEndCapture
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
typedef mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat162 nv_bfloat162;
+4
View File
@@ -121,6 +121,10 @@ if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
endif()
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()
if (NOT GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
+61
View File
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
static inline float ggml_e8m0_to_fp32(uint8_t x) {
uint32_t bits; // Stores the raw bit representation of the float
// Handle special case for minimum exponent (denormalized float)
if (x == 0) {
// Bit pattern for 2^(-127):
// - Sign bit: 0 (positive)
// - Exponent: 0 (denormalized number)
// - Mantissa: 0x400000 (0.5 in fractional form)
// Value = 0.5 * 2^(-126) = 2^(-127)
bits = 0x00400000;
}
// note: disabled as we don't need to handle NaNs
//// Handle special case for NaN (all bits set)
//else if (x == 0xFF) {
// // Standard quiet NaN pattern:
// // - Sign bit: 0
// // - Exponent: all 1s (0xFF)
// // - Mantissa: 0x400000 (quiet NaN flag)
// bits = 0x7FC00000;
//}
// Normalized values (most common case)
else {
// Construct normalized float by shifting exponent into position:
// - Exponent field: 8 bits (positions 30-23)
// - Mantissa: 0 (implicit leading 1)
// Value = 2^(x - 127)
bits = (uint32_t) x << 23;
}
float result; // Final float value
// Safely reinterpret bit pattern as float without type-punning issues
memcpy(&result, &bits, sizeof(float));
return result;
}
// Equal to ggml_e8m0_to_fp32/2
// Useful with MXFP4 quantization since the E0M2 values are doubled
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
uint32_t bits;
// For x < 2: use precomputed denormal patterns
if (x < 2) {
// 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
bits = 0x00200000 << x;
}
// For x >= 2: normalized exponent adjustment
else {
// 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
bits = (uint32_t)(x - 1) << 23;
}
// Note: NaNs are not handled here
float result;
memcpy(&result, &bits, sizeof(float));
return result;
}
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
/**
* Converts brain16 to float32.
*
+14
View File
@@ -23,6 +23,9 @@
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_R0_Q2_K 4
#define N_SG_Q2_K 2
@@ -129,6 +132,15 @@ typedef struct {
uint64_t o1[8];
} ggml_metal_kargs_bin;
typedef struct {
int64_t ne0;
int64_t ne1;
size_t nb01;
size_t nb02;
size_t nb11;
size_t nb21;
} ggml_metal_kargs_add_id;
typedef struct {
int32_t ne00;
int32_t ne01;
@@ -444,6 +456,8 @@ typedef struct{
uint64_t nb1;
int32_t i00;
int32_t i10;
float alpha;
float limit;
} ggml_metal_kargs_glu;
typedef struct {
+109 -9
View File
@@ -195,6 +195,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
GGML_METAL_KERNEL_TYPE_ADD_ID,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -234,6 +235,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -286,6 +288,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -310,6 +313,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -351,6 +358,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -373,6 +381,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -397,6 +406,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -579,6 +589,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
@@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
@@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
@@ -1377,6 +1396,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
@@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
@@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
@@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node(
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
@@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
case GGML_OP_ADD_ID:
{
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(src1t == GGML_TYPE_F32);
GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(dstt == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_rows(src0));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
ggml_metal_kargs_add_id args = {
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb11 =*/ nb11,
/*.nb21 =*/ nb21,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
@@ -2710,6 +2768,9 @@ static int ggml_metal_encode_node(
case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break;
case GGML_GLU_OP_SWIGLU_OAI:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
@@ -2720,7 +2781,9 @@ static int ggml_metal_encode_node(
GGML_ABORT("fatal error");
}
const int32_t swp = ((const int32_t *) dst->op_params)[1];
const int32_t swp = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
const int32_t i00 = swp ? ne0 : 0;
const int32_t i10 = swp ? 0 : ne0;
@@ -2734,6 +2797,8 @@ static int ggml_metal_encode_node(
/*.nb1 =*/ nb1,
/*.i00 =*/ src1 ? 0 : i00,
/*.i10 =*/ src1 ? 0 : i10,
/*.alpha=*/ alpha,
/*.limit=*/ limit
};
[encoder setComputePipelineState:pipeline];
@@ -2992,8 +3057,13 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
if (id_src2) {
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
} else {
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&args length:sizeof(args) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
@@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node(
src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 ||
src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_MXFP4 ||
src0t == GGML_TYPE_IQ4_NL ||
false) && (ne11 >= 2 && ne11 <= 8)
) ||
@@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node(
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_MXFP4:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q4_K:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
@@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node(
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
{
nsg = N_SG_MXFP4;
nr0 = N_R0_MXFP4;
smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
@@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node(
case GGML_OP_MUL_MAT_ID:
{
// src2 = ids
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(!ggml_is_transposed(src0));
@@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
@@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node(
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
{
nsg = N_SG_MXFP4;
nr0 = N_R0_MXFP4;
smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
@@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
@@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node(
GGML_ASSERT(ne11 == ne21);
GGML_ASSERT(ne12 == ne22);
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src3 = node->src[3]; // mask
struct ggml_tensor * src4 = node->src[4]; // sinks
size_t offs_src3 = 0;
size_t offs_src4 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
@@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node(
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
float scale;
float max_bias;
float logit_softcap;
@@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
if (id_src4) {
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
+272 -15
View File
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static inline float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
return as_type<float>(bits);
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst.qs[j] = round(x0);
}
}
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
const float d = e8m0_to_fp32(xb->e);
const uint8_t shr = il >= 1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
dst.d = d;
const float d = e8m0_to_fp32(xb->e);
const short il4 = il%4;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
const uint8_t shr = il >= 4 ? 4 : 0;
dst.qs[j] = round(x0);
}
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
@@ -960,6 +1006,32 @@ kernel void kernel_div(
}
}
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i1 = tgpig.x;
const int i2 = tgpig.y;
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
@@ -1431,6 +1503,32 @@ kernel void kernel_swiglu(
}
}
kernel void kernel_swiglu_oai(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, args.limit);
x1 = max(min(x1, args.limit), -args.limit);
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = out_glu;
}
}
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
@@ -1534,6 +1632,7 @@ template<typename T>
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
@@ -1552,6 +1651,7 @@ kernel void kernel_soft_max(
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
@@ -1567,7 +1667,7 @@ kernel void kernel_soft_max(
}
// parallel max
float lmax = -INFINITY;
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1623,6 +1723,10 @@ kernel void kernel_soft_max(
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1634,6 +1738,7 @@ template<typename T>
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
@@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4(
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
@@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4(
}
// parallel max
float4 lmax4 = -INFINITY;
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4(
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
@@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
}
}
if (sinks != q && sgitg == 0) {
for (ushort j = 0; j < Q; ++j) {
const float m = M[j];
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
M[j] = simd_max(max(M[j], s));
const float ms = exp(m - M[j]);
const float vs = exp(s - M[j]);
S[j] = S[j]*ms + simd_sum(vs);
if (tiisg == j) {
ss[j*TS + 2*C + j] = ms;
}
}
// O = diag(ms)*O
{
s8x8_t ms;
simdgroup_load(ms, ss + 2*C, TS, 0, false);
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
simdgroup_multiply(lo[i], ms, lo[i]);
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = tiisg; j < Q; j += NW) {
ss[j*TS + 0] = S[j];
@@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
}
}
if (sinks != q && sgitg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
M = simd_max(max(M, s));
const float ms = exp(m - M);
const float vs = exp(s - M);
S = S*ms + simd_sum(vs);
#pragma unroll(DV4/NL)
for (short ii = 0; ii < DV4; ii += NL) {
lo[ii/NL] *= ms;
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = (s_t) S;
@@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
device const float * yb = y + ix * QK_MXFP4 + it * 8;
for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
#pragma unroll(nr0)
for (short row = 0; row < nr0; row++) {
device const block_mxfp4 & xb = x[row*nb + ib];
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
acc1 = (acc1 + acc3) + (acc2 + acc4);
sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
}
yb += 16 * QK_MXFP4;
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
}
}
}
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
@@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
+1
View File
@@ -55,6 +55,7 @@ endfunction()
set(GGML_OPENCL_KERNELS
add
add_id
argsort
clamp
cpy
+145 -27
View File
@@ -345,6 +345,7 @@ struct ggml_backend_opencl_context {
cl_command_queue queue;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
cl_program program_cpy;
cl_program program_cvt;
@@ -404,6 +405,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
cl_kernel kernel_add_id;
cl_kernel kernel_scale;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
@@ -412,7 +414,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
@@ -600,6 +602,7 @@ struct ggml_backend_opencl_context {
if (ref_count == 0) {
#ifdef GGML_OPENCL_PROFILING
write_profiling_info();
profiling_info.clear();
#endif
}
}
@@ -681,6 +684,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// add_id
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "add_id.cl.h"
};
#else
const std::string kernel_src = read_file("add_id.cl");
#endif
backend_ctx->program_add_id =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
GGML_LOG_CONT(".");
}
// clamp
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -787,6 +806,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu_oai = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
@@ -2046,8 +2066,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);
backend_ctx->has_vector_subgroup_broadcast =
backend_ctx->adreno_cl_compiler_version.major >= 47 ||
backend_ctx->adreno_cl_compiler_version.major == 17;
(backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) ||
(backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17);
GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n",
backend_ctx->has_vector_subgroup_broadcast ? "true" : "false");
@@ -2467,6 +2487,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return (op->src[0]->type == op->src[1]->type) &&
(op->src[0]->type == op->type) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
@@ -2488,6 +2510,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -2601,10 +2624,10 @@ ggml_backend_t ggml_backend_opencl_init(void) {
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_opencl_guid(),
/* .interface = */ ggml_backend_opencl_i,
/* .device = */ dev,
/* .context = */ backend_ctx
/* .guid = */ ggml_backend_opencl_guid(),
/* .iface = */ ggml_backend_opencl_i,
/* .device = */ dev,
/* .context = */ backend_ctx
};
return backend;
@@ -3822,6 +3845,75 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
}
}
static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(src2);
GGML_ASSERT(src2->extra);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_rows(src0));
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb21 = src2->nb[1];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel = backend_ctx->kernel_add_id;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
size_t local_work_size[] = { (size_t)nth, 1, 1 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -6500,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ASSERT(src1->extra);
}
const ggml_tensor * src2 = dst->src[2];
if (src2) {
GGML_ASSERT(src2->extra);
}
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
@@ -6578,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -7003,6 +7104,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_swiglu_f16;
}
break;
case GGML_GLU_OP_SWIGLU_OAI:
kernel = backend_ctx->kernel_swiglu_oai;
break;
case GGML_GLU_OP_GEGLU_ERF:
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_geglu_erf;
@@ -7038,7 +7142,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
const cl_ulong nb1 = dst->nb[1];
const int swp = ((const int32_t *) dst->op_params)[1];
const int swp = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
@@ -7055,6 +7162,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
}
const size_t nrows = ggml_nrows(src0);
size_t nth = 512;
size_t global_work_size[] = {nrows*nth, 1, 1};
@@ -7111,6 +7223,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_add;
break;
case GGML_OP_ADD_ID:
if (!any_on_device) {
return false;
}
func = ggml_cl_add_id;
break;
case GGML_OP_MUL:
if (!any_on_device) {
return false;
+42
View File
@@ -0,0 +1,42 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// add_id
//------------------------------------------------------------------------------
kernel void kernel_add_id(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb02,
ulong nb11,
ulong nb21,
int ne0,
int ne1
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
src2 = (global char*)((global char*)src2 + offset2);
dst = (global char*)((global char*)dst + offsetd);
int i1 = get_group_id(0);
int i2 = get_group_id(1);
const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
global float * dst_row = (global float *)((global char *)dst + i1*nb1 + i2*nb2);
global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
+41
View File
@@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
}
}
//------------------------------------------------------------------------------
// swiglu_oai
//------------------------------------------------------------------------------
kernel void kernel_swiglu_oai(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off,
float limit,
float alpha
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, limit);
x1 = max(min(x1, limit), -limit);
float out_glu = x0 / (1.0f + exp(-x0 * alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = out_glu;
}
}
//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
}
// parallel max
float4 lmax4 = -INFINITY;
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
}
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
}
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
pdst4[i00] /= sum;
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
}
// parallel max
float4 lmax4 = -INFINITY;
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
}
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
pdst4[i00] /= sum;
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
}
// parallel max
float lmax = -INFINITY;
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
pdst[i00] = exp_psrc0;
}
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
pdst[i00] /= sum;
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
}
// parallel max
float lmax = -INFINITY;
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
pdst[i00] = exp_psrc0;
}
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
pdst[i00] /= sum;
+105 -11
View File
@@ -21,6 +21,17 @@
#define UNUSED GGML_UNUSED
static inline int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
}
}
static inline int best_index_mxfp4(float x, float e) {
int best_index = 0;
float best_err = fabsf(kvalues_mxfp4[0]*e - x);
for (int i = 1; i < 16; i++) {
float err = fabsf(kvalues_mxfp4[i]*e - x);
if (err < best_err) {
best_index = i;
best_err = err;
}
}
return best_index;
}
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
}
}
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
const float d = GGML_E8M0_TO_FP32_HALF(e);
y[i].e = e;
for (int j = 0; j < qk/2; ++j) {
const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
y[i].qs[j] = x0;
y[i].qs[j] |= x1 << 4;
}
}
}
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
}
}
void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
for (int j = 0; j < qk/2; ++j) {
const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
//
// 2-6 bit quantization in super-blocks
//
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * row_size;
}
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
// ============================ 4-bit non-linear quants
static inline int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
return true;
}
static bool validate_e_e8m0(uint8_t e, size_t i) {
if (e == 0xff) {
fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
return false;
}
return true;
}
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
} \
}
#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
if (!validate_e_e8m0(q[i].e, i)) { \
return false; \
} \
}
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
} break;
case GGML_TYPE_MXFP4:
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
} break;
case GGML_TYPE_Q2_K:
{
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+6
View File
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 *
GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG
GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
GGML_API void iq3xs_init_impl(int grid_size);
+4 -4
View File
@@ -823,10 +823,10 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
};
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .interface = */ ggml_backend_rpc_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint),
/* .context = */ ctx
/* .guid = */ ggml_backend_rpc_guid(),
/* .iface = */ ggml_backend_rpc_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint),
/* .context = */ ctx
};
return backend;
}
+29 -18
View File
@@ -2609,6 +2609,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->ne[1] == 1);
GGML_ASSERT(src1->ne[3] == 1);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
@@ -2688,6 +2690,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
const size_t type_size_src0 = ggml_type_size(src0->type);
const size_t type_size_src1 = ggml_type_size(src1->type);
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
// SRC1 strides
int64_t s11 = nb11 / type_size_src1;
int64_t s12 = nb12 / type_size_src1;
@@ -2737,6 +2742,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
s11 = ne10;
s12 = ne11 * s11;
s13 = ne12 * s12;
is_src1_cont_2 = true;
}
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2852,12 +2859,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
else
#endif
{
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
const int64_t smb = ne12 == 1 ? s13 : s12;
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;
@@ -3187,7 +3198,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
}
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
// KQV single-batch
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
@@ -4182,15 +4193,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
struct ggml_tensor * a;
struct ggml_tensor * b;
if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0];
b = op->src[1];
} else {
a = op->src[2];
b = op->src[1];
}
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
if (a->ne[3] != b->ne[3]) {
return false;
}
@@ -4205,7 +4210,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
}
}
ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16) {
if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
// TODO: support MXFP4
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
return true;
@@ -4350,6 +4357,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[3] != 1) {
return false;
}
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -4575,10 +4586,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
};
ggml_backend_t sycl_backend = new ggml_backend {
/* .guid = */ ggml_backend_sycl_guid(),
/* .interface = */ ggml_backend_sycl_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
/* .context = */ ctx
/* .guid = */ ggml_backend_sycl_guid(),
/* .iface = */ ggml_backend_sycl_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
/* .context = */ ctx
};
return sycl_backend;
+241 -50
View File
@@ -449,6 +449,8 @@ struct vk_device_struct {
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_add_id_f32;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
vk_pipeline pipeline_scale_f32;
@@ -483,6 +485,7 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2];
vk_pipeline pipeline_swiglu_oai[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
@@ -531,6 +534,7 @@ struct vk_device_struct {
ggml_backend_buffer_type buffer_type;
bool disable_fusion;
bool disable_host_visible_vidmem;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
@@ -705,6 +709,8 @@ struct vk_op_glu_push_constants {
uint32_t ne00;
uint32_t ne20;
uint32_t mode; // 0: default, 1: swapped, 2: split
float alpha; // for swiglu_oai
float limit;
};
struct vk_op_unary_push_constants {
@@ -794,6 +800,15 @@ struct vk_op_binary_push_constants {
float param1; float param2; int32_t param3;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
uint32_t s01;
uint32_t s02;
uint32_t s11;
uint32_t s21;
};
struct vk_op_diag_mask_push_constants {
uint32_t ncols;
uint32_t rows_per_channel;
@@ -835,6 +850,7 @@ struct vk_op_soft_max_push_constants {
float m1;
uint32_t n_head_log2;
uint32_t nrows_x;
uint32_t has_sinks;
};
struct vk_op_argsort_push_constants {
@@ -1789,6 +1805,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
} else if (device->uma) {
// Fall back to host memory type
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
} else if (device->disable_host_visible_vidmem) {
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
} else {
// use rebar if available, otherwise fallback to device only visible memory
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -1977,6 +1995,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
break;
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
lut_size = 4*16;
break;
default:
@@ -2106,12 +2125,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul (Qi_K)
l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
l_mmq_wg_denoms_k = { 64, 128, 1 };
m_mmq_wg_denoms_k = { 32, 64, 1 };
s_mmq_wg_denoms_k = { 32, 32, 1 };
l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
l_mmq_wg_denoms_k = { 128, 256, 1 };
m_mmq_wg_denoms_k = { 128, 128, 1 };
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
@@ -2267,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
};
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
@@ -2353,6 +2372,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
@@ -2379,6 +2399,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2440,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2461,6 +2483,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2493,6 +2516,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2514,6 +2538,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
@@ -2581,6 +2606,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2618,6 +2644,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
@@ -2672,6 +2699,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2709,6 +2737,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
@@ -2767,6 +2796,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2790,6 +2820,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2814,6 +2845,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2836,6 +2868,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2855,6 +2888,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2873,9 +2907,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2976,6 +3011,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_BINARY(div, _norepeat, {1})
#undef CREATE_BINARY
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3026,6 +3063,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_GLU(geglu)
CREATE_GLU(reglu)
CREATE_GLU(swiglu)
CREATE_GLU(swiglu_oai)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
@@ -3035,10 +3073,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -3096,6 +3134,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t conv2d_SHMEM_PAD = 4;
bool conv2d_UNROLL = true;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
conv2d_SHMEM_PAD = 8; // 8 float16_t
}
#endif
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
conv2d_SHMEM_PAD = 0;
conv2d_UNROLL = false;
@@ -3154,6 +3198,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
} else
#endif
if (conv2d_UNROLL) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
@@ -3214,6 +3268,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@@ -4228,6 +4285,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4298,6 +4356,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4341,6 +4400,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4395,6 +4455,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4430,6 +4491,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4615,6 +4677,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -5022,26 +5085,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
ggml_vk_queue_command_pools_cleanup(dst->device);
}
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
uint32_t split_k = 1;
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
// If k is 'large' and the SMs will fill less than halfway, use split_k.
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
// Clamp to 2 or 4
split_k = std::min(split_k, 4u);
if (split_k == 3) {
split_k = 2;
if (k >= 2048) {
if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
} else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
split_k = 3;
}
if (ctx->device->coopmat2) {
// coopmat2 shader expects splits to be aligned to 256
while (split_k > 1 && ((k / split_k) % 256) != 0) {
split_k /= 2;
// Cap the split at 8x. Unless k is huge this is a lot of overhead.
split_k = std::min(split_k, 8u);
// ggml_vk_matmul will align the splits to be a multiple of 256.
// If this rounded up size would cause the last split to be empty,
// then reduce the split count.
while (true) {
if (split_k == 1) {
break;
}
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
if (k_split * (split_k - 1) < k) {
break;
}
split_k--;
}
}
}
@@ -5053,9 +5127,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
if (ctx->device->coopmat2) {
const uint32_t shader_core_count = ctx->device->shader_core_count;
const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
// Prefer large over medium if either:
// - medium or large tiles would overfill the GPU
// - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
// (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
// split_k==3 with large tiles likely better than medium tiles with no split_k.
(tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
// Use medium shader when the N dimension is greater than the small shader's tile size
@@ -5099,7 +5186,11 @@ static void ggml_vk_matmul(
GGML_ASSERT(batch_stride_d == m * n);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
// Round the split size up to a multiple of 256 (k-quant alignment)
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx);
@@ -6416,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
return supported;
}
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
if (sinks) {
std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
}
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
@@ -6619,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
@@ -6637,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
M_uma = d_M != nullptr;
}
if (sinks) {
ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
S_uma = d_S != nullptr;
}
}
@@ -6672,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
if (!S_uma) {
d_S = d_Q;
s_buf_offset = q_buf_offset;
if (sinks) {
ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
d_S = s_buf_ctx->dev_buffer;
s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
}
}
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -6696,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
},
// We only use split_k when group query attention is enabled, which means
@@ -6705,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
@@ -6719,6 +6829,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc, { workgroups_x, workgroups_y, workgroups_z });
@@ -6803,6 +6914,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
break;
}
return nullptr;
case GGML_OP_ADD_ID:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_add_id_f32;
}
return nullptr;
case GGML_OP_CONCAT:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_concat_f32;
@@ -6948,6 +7064,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU_OAI:
return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
@@ -6963,6 +7081,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -7133,6 +7252,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SQR:
@@ -7479,6 +7599,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { ne, 1, 1 };
}
} break;
case GGML_OP_ADD_ID:
{
elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
} break;
case GGML_OP_SET_ROWS:
{
uint32_t ne = ggml_nelements(src0);
@@ -7518,8 +7642,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
// Empty src1 is possible in soft_max, but the shader needs a buffer
if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -7529,6 +7653,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_SOFT_MAX) {
// Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
} else {
subbuf_y = { d_X, 0, x_sz };
}
vk_subbuffer subbuf_z;
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
subbuf_z = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z;
@@ -7657,6 +7799,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t src2_type_size = ggml_type_size(src2->type);
ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
(uint32_t)dst->ne[0],
(uint32_t)dst->ne[1],
(uint32_t)src0->nb[1] / src0_type_size,
(uint32_t)src0->nb[2] / src0_type_size,
(uint32_t)src1->nb[1] / src1_type_size,
(uint32_t)src2->nb[1] / src2_type_size,
}, dryrun);
}
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
@@ -8075,8 +8232,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const float * op_params_f = (const float *)dst->op_params;
const bool swapped = (bool)dst->op_params[1];
const bool split = src1 != nullptr;
const float alpha = op_params_f[2];
const float limit = op_params_f[3];
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -8090,7 +8251,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
{
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0],
(uint32_t)dst->ne[0],
mode,
alpha,
limit
}, dryrun);
}
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -8098,7 +8267,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
float scale = op_params[0];
@@ -8120,7 +8289,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -8130,6 +8299,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
m0, m1,
n_head_log2,
nrows_x,
src2 != nullptr
}, dryrun);
}
@@ -9369,6 +9539,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break;
@@ -9380,6 +9551,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9534,6 +9706,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_DIV:
ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_ADD_ID:
ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break;
case GGML_OP_CONCAT:
ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9631,6 +9807,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9644,7 +9821,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SOFT_MAX:
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break;
case GGML_OP_SOFT_MAX_BACK:
@@ -9717,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
break;
@@ -9790,6 +9967,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
@@ -9859,6 +10037,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer;
@@ -10588,10 +10767,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
ggml_vk_init(ctx, dev_num);
ggml_backend_t vk_backend = new ggml_backend {
/* .guid = */ ggml_backend_vk_guid(),
/* .interface = */ ggml_backend_vk_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
/* .context = */ ctx,
/* .guid = */ ggml_backend_vk_guid(),
/* .iface = */ ggml_backend_vk_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
/* .context = */ ctx,
};
return vk_backend;
@@ -10708,6 +10887,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
@@ -10753,6 +10933,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return false;
@@ -10790,6 +10971,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
return false;
}
if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}
@@ -10862,6 +11046,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
return true;
default:
return false;
@@ -10960,6 +11145,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SQR:
@@ -11378,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
if (src_clone[4]) {
ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
}
} else if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL_MAT_ID) {
@@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.comp"
layout (push_constant) uniform parameter
{
uint ne0;
uint ne1;
uint s01;
uint s02;
uint s11;
uint s21;
} p;
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {int32_t data_c[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i1 = gl_WorkGroupID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i11 = data_c[i1 + i2 * p.s21];
const uint s1 = p.ne0;
const uint s2 = p.ne0 * p.ne1;
const uint d0 = i1 * s1 + i2 * s2;
const uint a0 = i1 * p.s01 + i2 * p.s02;
const uint b0 = i11 * p.s11;
for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
}
}
@@ -1,6 +1,11 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
@@ -91,6 +96,12 @@ uint32_t n_elems_out = K * NPQ;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#ifdef COOPMAT2
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
@@ -100,8 +111,8 @@ const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared float Ash[Ash_len]; // K x CRS
shared float Bsh[Bsh_len]; // CRS x NPQ
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
@@ -110,10 +121,6 @@ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
float regA[TS_K];
float regB[TS_NPQ];
float regC[TS_K][TS_NPQ];
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
@@ -145,12 +152,36 @@ uint fastdiv(uint n, uint mp, uint L) {
return (msbs + n) >> L;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
@@ -199,7 +230,7 @@ void main() {
if (K_idx >= K || CRS_idx_a >= CRS) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = val;
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
@@ -244,11 +275,21 @@ void main() {
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = val;
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
@@ -262,9 +303,13 @@ void main() {
}
}
}
#endif
barrier();
}
/* Save C* */
#ifdef COOPMAT2
coopMatPerElementNV(matC, matC, perElemOpStore);
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -280,4 +325,5 @@ void main() {
}
}
}
#endif
}
@@ -4,8 +4,8 @@
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
@@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
vec2 v1 = dequantize(ib, iqs + 1, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
@@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
@@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
return ret;
}
#endif
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
@@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#endif
@@ -0,0 +1,32 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
}
}
@@ -305,6 +305,27 @@ void main() {
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
@@ -50,10 +50,13 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;
#define SINK_ENABLE_BIT (1<<24)
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
@@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// Load the sink value, indexed by Q's dimension 2.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
return ACC_TYPE(data_s[h]);
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
@@ -329,6 +329,27 @@ void main() {
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
@@ -248,6 +248,34 @@ void main() {
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
// resize M by using smear/reduce
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
// O, Ldiag, Mr all have the same type so all element locations match
[[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
ACC_TYPE sink = S[i];
ACC_TYPE ms = ACC_TYPE(1.0f);
ACC_TYPE vs = ACC_TYPE(1.0f);
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);
O[i] *= ms;
} else {
vs = exp(sink - Mr[i]);
}
Ldiag[i] = Ldiag[i]*ms + vs;
}
}
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
layout (binding = 1) readonly buffer B {float data_s[];};
layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
uint sinks;
} p;
shared float tmpsh[BLOCK_SIZE];
@@ -73,6 +75,22 @@ void main() {
}
L = tmpsh[0];
float sink;
if (p.sinks != 0) {
sink = data_s[n];
float ms = 1.0f;
float vs = 1.0f;
if (sink > m_max) {
ms = exp(m_max - sink);
} else {
vs = exp(sink - m_max);
}
L = L*ms + vs;
}
L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
@@ -85,6 +103,13 @@ void main() {
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
if (p.sinks != 0) {
if (sink > m_max) {
float ms = 1.0f;
ms = exp(m_max - sink);
O *= ms;
}
}
O *= L;
data_d[iq3 * D * N + D * n + d] = O;
}
@@ -14,4 +14,6 @@ layout (push_constant) uniform parameter
uint ne00;
uint ne20;
uint mode;
float alpha;
float limit;
} p;
@@ -747,6 +747,21 @@ void main() {
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
const float d = e8m0_to_fp32(data_a[ib].e);
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
@@ -20,6 +20,7 @@ layout (push_constant) uniform parameter
float m1;
uint n_head_log2;
uint nrows_x;
uint has_sinks;
} p;
#include "types.comp"
@@ -29,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
layout (binding = 2) readonly buffer Z {float data_c[];};
layout (binding = 3) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
@@ -60,13 +62,13 @@ void soft_max(uint num_iters) {
const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
slope = pow(base, exp);
}
// Find max
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
// Cache values while we compute the max, so we don't need to read them
// again when we're ready to compute exp(x-max).
@@ -148,6 +150,10 @@ void soft_max(uint num_iters) {
}
sum = vals[0];
if (p.has_sinks != 0) {
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
}
FLOAT_TYPE rcpdivisor = 1.0/sum;
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
@@ -0,0 +1,14 @@
#version 450
#include "glu_head.comp"
float op(float a, float b) {
float xi = min(a, p.limit);
float gi = max(min(b, p.limit), -p.limit);
float out_glu = xi / (1.0f + exp(-xi * p.alpha));
out_glu = out_glu * (1.0f + gi);
return out_glu;
}
#include "glu_main.comp"
@@ -1337,6 +1337,29 @@ struct block_iq4_nl_packed16
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
#define QUANT_K_MXFP4 32
#define QUANT_R_MXFP4 2
struct block_mxfp4
{
uint8_t e;
uint8_t qs[QUANT_K_MXFP4/2];
};
//struct block_mxfp4_packed16
//{
// uint8_t e;
// uint16_t qs[QUANT_K_MXFP4/2/2];
//};
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
@@ -1356,6 +1379,25 @@ void init_iq_shmem(uvec3 wgsize)
}
#endif
#if defined(DATA_A_MXFP4)
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
};
shared FLOAT_TYPE kvalues_mxfp4[16];
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
}
barrier();
}
#endif
// returns the bfloat value in the low 16b.
// See ggml_compute_fp32_to_bf16
uint32_t fp32_to_bf16(float f)
@@ -1370,4 +1412,17 @@ float bf16_to_fp32(uint32_t u)
return uintBitsToFloat(u << 16);
}
float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = x;
bits = bits << 23;
}
return uintBitsToFloat(bits);
}
#endif // !defined(GGML_TYPES_COMP)

Some files were not shown because too many files have changed in this diff Show More