Compare commits

...

63 Commits

Author SHA1 Message Date
Acly 385c3da5e6 vulkan : fix FA mask load with bounds check (coopmat2) (#17606) 2025-11-30 01:03:21 +01:00
Xuan-Son Nguyen ab49f094d2 server: move server-context to its own cpp|h (#17595)
* git mv

* add server-context.h

* add server-context.h

* clean up headers

* cont : cleanup

* also expose server_response_reader (to be used by CLI)

* fix windows build

* decouple server_routes and server_http

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-29 22:04:44 +01:00
Haiyue Wang 8c32d9d96d server: explicitly set the function name in lambda (#17538)
As [1] explained, the real debug message will be like:
	"res    operator(): operator() : queue result stop"

Set the name explicitly, the message is easy for debugging:
	"res    operator(): recv : queue result stop"

The left "operator()" is generated by 'RES_DBG() ... __func__'

[1]: https://clang.llvm.org/extra/clang-tidy/checks/bugprone/lambda-function-name.html

Signed-off-by: Haiyue Wang <haiyuewa@163.com>
2025-11-29 18:43:29 +01:00
Igor Smirnov 0874693b44 common : fix json schema with '\' in literals (#17307)
* Fix json schema with '\' in literals

* Add "literal string with escapes" test
2025-11-29 17:06:32 +01:00
Neo Zhang 7d2add51d8 sycl : support to malloc memory on device more than 4GB, update the doc and script (#17566)
Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2025-11-29 14:59:44 +02:00
ixgbe f698a79c63 ggml: replace hwcap with riscv_hwprobe for RVV detection (#17567)
Signed-off-by: Wang Yang <yangwang@iscas.ac.cn>
2025-11-29 14:56:31 +02:00
Ruben Ortlam 47a268ea50 Vulkan: MMVQ Integer Dot K-Quant and MUL_MAT_ID support (#16900)
* vulkan: split mul_mmq_funcs for mul_mat_vecq use

* add mxfp4 mmvq

* add q2_k mmvq

* add q3_k mmvq

* add q4_k and q5_k mmvq

* add q6_k mmvq

* handle 4x4 quants per mmvq thread

* enable MUL_MAT_ID mmvq support

* enable subgroup optimizations for mul_mat_vec_id shaders

* device tuning

* request prealloc_y sync after quantization

* fix indentation

* fix llvmpipe test failures

* fix mul_mat_id mmvq condition

* fix unused variable warning
2025-11-29 09:37:22 +01:00
Jeff Bolz 59d8d4e963 vulkan: improve topk perf for large k, fix overflow in unit tests (#17582) 2025-11-29 08:39:57 +01:00
Aleksei Nikiforov d82b7a7c1d gguf-py : fix passing non-native endian tensors (editor-gui and new-metadata) (#17553)
gguf_new_metadata.py reads data from reader.
Reader doesn't byteswap tensors to native endianness.
But writer does expect tensors in native endianness to convert them
into requested endianness.

There are two ways to fix this: update reader and do conversion to native endianness and back,
or skip converting endianness in writer in this particular USE-case.

gguf_editor_gui.py doesn't allow editing or viewing tensor data.
Let's go with skipping excessive byteswapping.

If eventually capability to view or edit tensor data is added,
tensor data should be instead byteswapped when reading it.
2025-11-28 20:53:01 +01:00
DAN™ 03914c7ef8 common : move all common_chat_parse_* to chat-parser.cpp. (#17481) 2025-11-28 19:29:36 +01:00
o7si 3ce7a65c2f server: fix: /metrics endpoint returning JSON-escaped Prometheus format (#17386)
* fix: /metrics endpoint returning JSON-escaped Prometheus format

* mod: remove string overload from ok() method
2025-11-28 19:14:00 +01:00
Diego Devesa e072b2052e ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched (#17276)
* ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched
Enabled in ggml-ci for testing.

* llama : update worst-case graph for unified cache

* ci : disable op offload in some tests

* fix spelling

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-28 17:33:23 +02:00
R0CKSTAR c6f7a423c8 [MUSA] enable fp16/fast_fp16/bf16_mma on PH1 (#17551)
* [MUSA] enable fp16/fast_fp16/bf16_mma on PH1

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-tile.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Address review comments

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-11-28 14:08:29 +01:00
Aman Gupta 2e7ef98f18 ggml-cuda: add stricter checking for fusion (#17568)
* ggml-cuda: make conditions for fusion more explicit

* ggml-cuda: remove size check as std::equal already does it
2025-11-28 20:34:51 +08:00
Fredrik Hultin ddf9f94389 server : add Anthropic Messages API support (#17570)
* server : add Anthropic Messages API support

* remove -@pytest.mark.slow from tool calling/jinja tests

* server : remove unused code and slow/skip on test_anthropic_vision_base64_with_multimodal_model in test_anthropic_api.py

* server : removed redundant n field logic in anthropic_params_from_json

* server : use single error object instead of error_array in streaming response handler for /v1/chat/completions and use unordered_set instead of set in to_json_anthropic_stream()

* server : refactor Anthropic API to use OAI conversion

* make sure basic test always go first

* clean up

* clean up api key check, add test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-11-28 12:57:04 +01:00
Piotr Wilkin (ilintar) ff55414c42 model : Qwen3 Next (#16095)
* Qwen3 Next - cleaned up version

* Whitespaces and stuff

* Correct minor errors

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Misc. fixes.

* Clean up code, add missing hybrid qualifier

* Did someone transpose the SOLVE_TRI result matrix? Perhaps...

* Whitespace

* Proper tensors for cb calls

* Use llama-graph.h vertical alignment

* BROKEN: chunking

* Set new tensors as inputs.

* Proper chunk logic

* It's the circle of life...

* More shenanigans for n_seq > 1

* Nail in the coffin?

* Fix Windows build

* Eh, one fails on Windows, the other fails on Mac... just use general capture.

* quant : cleanup

* model : cleanup

* qwen3 : cleanup

* cont : cleanup

* cont : cleanup

* ggml : revert change

* qwen3 : cleanup

* cont : cleanup

* Readd cmath

* qwen3 : fix typo

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Usual suspects

* fix my bad suggestion

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-28 12:02:56 +01:00
Johannes Gäßler 73955f7d2a CUDA: no FP16 arithmetic for vector FA kernel (#17558) 2025-11-28 10:29:09 +01:00
Jeff Bolz 35cf8887e1 vulkan: Implement GGML_OP_TRI (#17503)
* vulkan: Implement GGML_OP_TRI

* check types match
2025-11-28 10:07:29 +01:00
Radoslav Gerganov 15d2b46b4d rpc : cache and reuse compute graphs (#15405)
Store the last computed graph and reuse it when possible.
Also do not return response from GRAPH_COMPUTE and assume it always
completes successfully. If this this is not the case, the server closes
the connection. This saves us a network round trip to the server.
2025-11-28 08:33:51 +00:00
yulo 6bca76ff5e HIP: enable mul_mat_f for RDNA4 (#17437)
* enable mmf for rdna4

* move some mmvf to mmf

* revert lds128 for wmma loading

* Revert "revert lds128 for wmma loading"

This reverts commit db9ae8b6b4.

* Revert "enable mmf for rdna4"

This reverts commit 698c9f2418.

* Revert "move some mmvf to mmf"

This reverts commit 99b92bd665.

* enable mul_mat for rdna4

---------

Co-authored-by: zhang hui <you@example.com>
2025-11-28 08:24:30 +01:00
Piotr Wilkin (ilintar) cd0e3a7a3b SOLVE_TRI CUDA kernel for small matrices (#17457) 2025-11-28 12:15:32 +08:00
Neo Zhang Jianyu efaaccdd69 refactor pad_reflect_1d to make the UT case pass (#17204)
Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
2025-11-28 08:50:56 +08:00
Jeff Bolz 4abef75f2c vulkan: Implement SOLVE_TRI (#17486)
* vulkan: Implement SOLVE_TRI

* load B matrix through shared memory

* use FLOAT_TYPE
2025-11-27 15:48:00 +01:00
Georgi Gerganov c386114922 arch : add description about LLM_TENSOR_INFOS (#17550) 2025-11-27 16:34:13 +02:00
Georgi Gerganov 6783b11fb0 models : fix LFM2 tensors (#17548) 2025-11-27 16:04:29 +02:00
matt23654 909072abcf cuda : fix UMA detection on discrete GPUs. (#17537) 2025-11-27 13:35:35 +02:00
Alberto Cabrera Pérez cd8370b408 ggml-cpu: aarm64: q4_K repack gemm and gemv implementations (dotprod only) (#17494)
* Enabled q4_K_4x8 path

* Fixed generic Q4_K 8x4 implementation

* wip: dotprod gemm

* Working arm q4_K dotprod gemm

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Undo acc rename

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Q4_K arm dotprod gemm

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Fix: q4_qs reinterpret from uint to int

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Removed comments

* Fixed macro guards

* Fixed unused vars in generic implementation

* Fixed unused vars in 8x4 repack

* Fixed unused vars in generic implementation, unneeded comment

* Missing arch fallback for x86

* minor : style

---------

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-27 13:25:14 +02:00
Eric Curtin d21a76ac38 devops: Add build-essential to Ubuntu 26.04 image (#17531)
This is no longer passing the build, needs more packages.

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
2025-11-27 18:35:47 +08:00
Aleksei Nikiforov 4fcd87cf7c gguf-py : skip endian-conversion of MXFP4 data (#17523)
* gguf_convert_endian.py: skip MXFP4 data

* Use gguf.constants.GGML_QUANT_SIZES to determine block sizes
2025-11-27 11:35:38 +01:00
Acly b78db3bd50 vulkan : move contiguous checks to device_supports_op (#17490)
* vulkan : remove op_supports_incontiguous and add missing constraints in device_supports_op

* im2col: remove contraints on src0 (kernel input)
2025-11-27 06:54:19 +01:00
Jeff Bolz 142df17c9c vulkan: use a fixed 1KB buffer for the add_rms_fusion opt (#17514) 2025-11-27 06:32:30 +01:00
Xuan-Son Nguyen e509411cf1 server: enable jinja by default, update docs (#17524)
* server: enable jinja by default, update docs

* fix tests
2025-11-27 01:02:50 +01:00
lhez 7cba58bbea opencl: add sqr, sqrt, mean and ssm_conv (#17476)
* opencl: add sqr

* opencl: add sqrt

* opencl: add mean

* opencl: add ssm_conv

* opencl: add missing cl_khr_fp16

* opencl: do sqrt in f32 then convert to f16 for better precision
2025-11-26 13:29:58 -08:00
Alberto Cabrera Pérez 5449367b21 Fix chunks being too small with small matrix sizes (#17526) 2025-11-26 13:14:54 -08:00
Han Qingzhe 1d594c295c clip: (minicpmv) fix resampler kq_scale (#17516)
* debug:"solve minicpmv precision problem"

* “debug minicpmv”

* Apply suggestion from @ngxson

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-11-26 21:44:07 +01:00
Jeff Bolz eec1e33a9e vulkan: allow graph_optimize for prompt processing workloads (#17475) 2025-11-26 16:46:33 +01:00
Jeff Bolz 879d673759 vulkan: Implement top-k (#17418)
* vulkan: Implement top-k

Each pass launches workgroups that each sort 2^N elements (where N is usually 7-10)
and discards all but the top K. Repeat until only K are left. And there's a fast
path when K==1 to just find the max value rather than sorting.

* fix pipeline selection

* vulkan: Add N-ary search algorithm for topk

* microoptimizations
2025-11-26 16:45:43 +01:00
xctan 6ab4e50d9c ggml-cpu : add RISC-V Zvfh impl for ggml_vec_mad_f16 (#17448)
* ggml-cpu : add RISC-V Zvfh impl for ggml_vec_mad_f16

* ggml-cpu : dedup scalar impl

* Update ggml/src/ggml-cpu/vec.h

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-26 15:33:05 +02:00
Adrien Gallouët 2336cc4784 cmake : use EXCLUDE_FROM_ALL to avoid patch-boringssl.cmake (#17520)
We have to separate the code path starting 3.28 because
`FetchContent_Populate` is now deprecated and will be completely removed
in a future version.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-11-26 15:15:21 +02:00
Adrien Gallouët e6923caaec ggml : fix ARM feature verification (#17519)
On arm64 with `cmake` version 3.31.6, the final feature verification fails:

    -- ARM detected flags: -mcpu=neoverse-v2+crc+sve2-aes+sve2-sha3+nossbs
    -- Performing Test GGML_MACHINE_SUPPORTS_dotprod
    -- Performing Test GGML_MACHINE_SUPPORTS_dotprod - Success
    -- Performing Test GGML_MACHINE_SUPPORTS_i8mm
    -- Performing Test GGML_MACHINE_SUPPORTS_i8mm - Success
    -- Performing Test GGML_MACHINE_SUPPORTS_sve
    -- Performing Test GGML_MACHINE_SUPPORTS_sve - Success
    -- Performing Test GGML_MACHINE_SUPPORTS_sme
    -- Performing Test GGML_MACHINE_SUPPORTS_sme - Failed
    -- Performing Test GGML_MACHINE_SUPPORTS_nosme
    -- Performing Test GGML_MACHINE_SUPPORTS_nosme - Success
    -- Checking for ARM features using flags:
    --   -U__ARM_FEATURE_SME
    --   -mcpu=neoverse-v2+crc+sve2-aes+sve2-sha3+nossbs+dotprod+i8mm+sve+nosme
    -- Performing Test HAVE_DOTPROD
    -- Performing Test HAVE_DOTPROD - Failed
    -- Performing Test HAVE_SVE
    -- Performing Test HAVE_SVE - Failed
    -- Performing Test HAVE_MATMUL_INT8
    -- Performing Test HAVE_MATMUL_INT8 - Failed
    -- Performing Test HAVE_FMA
    -- Performing Test HAVE_FMA - Success
    -- Performing Test HAVE_FP16_VECTOR_ARITHMETIC
    -- Performing Test HAVE_FP16_VECTOR_ARITHMETIC - Failed
    -- Performing Test HAVE_SME
    -- Performing Test HAVE_SME - Failed
    -- Adding CPU backend variant ggml-cpu: -U__ARM_FEATURE_SME;-mcpu=neoverse-v2+crc+sve2-aes+sve2-sha3+nossbs+dotprod+i8mm+sve+nosme

We need to explicitly replace `;` with spaces from the list to make
`CMAKE_REQUIRED_FLAGS` work correctly...

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-11-26 15:14:41 +02:00
Jiacheng (Jason) Chen 3e18dba9fd HIP: Patch failed testcase in WMMA-MMQ kernels for RDNA 4 (#17502)
* patch failed test case MUL_MAT(type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) for enabling WMMA on RDNA4

* Quick clean up on mma.cuh to add ggml_cuda_memcpy_1 back in for half2 and bfloat162
2025-11-26 11:18:48 +01:00
hipudding eeb5605de2 CANN: Add MROPE and IMROPE support (#17401)
* CANN: ROPE supports both MROPE and IMROPE.

1. Optimize the caching logic of rope_cache_init.
2. Add support for mRoPE and i-mRoPE.

Note that on Ascend 910B devices, it is necessary to disable FA
in CLIP and disable NZ-format conversion. These two issues are
still under investigation.

* Resolve review comments
2025-11-26 16:44:19 +08:00
o7si f3a848a3b1 chore: upgrade cpp-httplib from v0.27.0 to v0.28.0 (#17513) 2025-11-26 09:21:06 +02:00
Jeff Bolz b3b03a7baf vulkan: Implement GGML_OP_CUMSUM (#17479) 2025-11-26 07:08:10 +01:00
Georgi Gerganov 583cb83416 ggml : add ggml_top_k (#17365)
* ggml : add ggml_top_k

* cont : add ggml_argsort_top_k

* metal : add top_k support

* ggml : cleanup

* tests : add virtual err() function for test_case

* ggml : add comments
2025-11-25 15:31:43 +02:00
Aleksei Nikiforov 05872ac885 convert : fix big-endian conversion (#17431)
* Fix convert_hf_to_gguf.py script on s390x

Assume converted model data is originally little-endian.
Byteswap data on s390x after reading it to put values in correct presentation
for any transformation needed, like calculating weight tensors.

Then byteswap data to little-endian before passing it to GGUFWriter while
GGUFWriter will byteswap data back to big endian if big endian output is requested.

byteswap(inplace=True) calls don't work with lazy tensor and array wrappers.
Use byteswap with copying data to workaround this behaviour.

* Make GGUFWriter accept tensors in native endianness instead of little-endian

With this change if no byteswapping is actually needed, 2 excessive byteswaps can be omitted on s390x

* Fix byteswapping in convert_hf_to_gguf.py for remote models
2025-11-25 14:18:16 +01:00
Diego Devesa 55ab25caf5 codeowners : remove slaren (#17492) 2025-11-25 13:00:23 +01:00
TianHao324 064c90d843 CANN: supports out_prod operator for F32 and F16 (#17406)
Co-authored-by: tianhao <tianhao42@huawei.com>
2025-11-25 17:39:06 +08:00
Pascal b1846f1c8e webui: add rehype plugin to restore HTML in Markdown table cells (#17477)
* webui: add rehype plugin to restore HTML in Markdown table cells

The remark/rehype pipeline neutralizes inline HTML as literal text
(remarkLiteralHtml) so that XML/HTML snippets in LLM responses display
as-is instead of being rendered. This causes <br> and <ul> markup in
table cells to show as plain text.

This plugin traverses the HAST post-conversion, parses whitelisted HTML
patterns (<br>, <ul><li>) from text nodes, and replaces them with actual
HAST element nodes. For lists, adjacent siblings must be combined first
as the AST fragmentation breaks pattern matching.

Strict validation rejects malformed markup, keeping it as raw text.

* chore: update webui build output
2025-11-25 08:01:02 +01:00
Jeff Bolz d414db02d3 vulkan: Use fewer rows for scalar FA when HS is not a multiple of 16 (#17455) 2025-11-25 07:11:27 +01:00
Aaron Teo 877566d512 llama: introduce support for model-embedded sampling parameters (#17120) 2025-11-25 09:56:07 +08:00
Jeff Bolz 3d07caa99b vulkan: more FA details in vk_perf_logger (#17443) 2025-11-24 22:25:24 +01:00
Daniel Bevenius 134e6940ca llama : skip output reordering for single token batches (#17466)
This commit adds a check to skip the output reordering logic when
n_outputs == 1. With a single output token, the data is trivially
sorted and the reordering code is currently doing unnecessary work
(resetting and rebuilding output_ids to the same values).

The motivation for this change is improved code clarity and avoiding
confusion when debugging. While the performance impact is probably
negligible, this unnecessary work happens on every decode call in
llama-server when processing batches with single-token outputs.
2025-11-24 21:06:17 +01:00
Jiacheng (Jason) Chen 0543f928a3 HIP: WMMA-MMQ kernels for RDNA 4 (#17156)
* first commit naive test to enable mmq for RDNA4

* adding appropriate WMMA instructions

* git rebase on top of master: fixing the correctness of the mat mul operations, updating layout mappings for RDNA4

* clean up merge conflicts

* add comments and code clean up

* PR clean up, addressed comments

* enable MMQ fallback on RDNA4

* addressed comments: add guards in load generic, separate wmma branch for use_mmq function

* Revert build-xcframework.sh

* Formating: remove trailing whitespace

* revert CMake files

* clean up after rebase: remove duplicated change, revert cmake files

* clean up after rebase: revert changes from build-xcframework.sh

* clean up: remove extra space line in mma.cuh

* Revert "clean up: remove extra space line in mma.cuh"

This reverts commit b39ed57c45.
2025-11-24 20:00:10 +01:00
Sigbjørn Skjæret b61de2b2df convert : allow quantizing lora again (#17453) 2025-11-24 15:50:55 +01:00
Xuan-Son Nguyen b8372eecd9 server: split server.cpp code into server/common/task/queue (#17362)
* add server-task, server-common

* add server-queue

* rm redundant includes

* move enum stop_type to server-task

* server : headers cleanup

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-24 14:41:53 +01:00
Daniel Bevenius 6ab8eacddf examples : add -kvu to batched usage example [no ci] (#17469)
This commit adds the --kv-unified flag to the usage example
in the README.md file for the batched example.

The motivation for this is that without this flag the example will fail
with the following error:
```console
Hello my name is
split_equal: sequential split is not supported when there are coupled
sequences in the input batch (you may need to use the -kvu flag)
decode: failed to find a memory slot for batch of size 4
main: llama_decode() failed
```
2025-11-24 15:38:45 +02:00
Georgi Gerganov 2d50b9d8cb sync : ggml 2025-11-24 15:26:31 +02:00
Daniel Bevenius 697edfeead ggml : remove dirty flag from version string (ggml/1391)
This commit removes the "-dirty" suffix from the GGML version string.

The motivation for this change is to ensure that the version string
works with different ways of checking out ggml and using it in projects.
By removing the dirty flag from the version string, we avoid potential
artifacts like shared libraries getting a -dirty suffix in their names.

Instead, if the project is built from a dirty git state, the dirty flag
will be appended to the commit hash in the GGML_BUILD_COMMIT variable.
This will enable users to still identify that the build was made from
from a modified/dirty state even though the version might match a "real"
version.

For example, the commit can be produces as follows:
```c++
    printf("commit: %s\n", ggml_commit());
```
Which would print the following for a dirty build:
```console
commit: 781baf2a-dirty
```

Refs: https://github.com/ggml-org/ggml/pull/1363#issuecomment-3569691546
2025-11-24 15:26:31 +02:00
Alberto Cabrera Pérez dbb852b549 ggml-cpu: arm64: q4_K repack gemm and gemv implementations (i8mm) (#16739)
* Enabled q4_K_8x8_q8_K path on ARM

* wip: I8mm qs multiplication, pending bias

* cpu : arm : REPACK gemm q4_K8x8 implementation

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Guard gemm with proper features, improved superblock scale and min calc

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* cpu: arm: Implemented REPACK gemv for Q4_K

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Removed completed TODO

* Fixed missing guards when selecting optimal repack type for Q4_K

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Fixed macro guard for gemv

* Fixed wrong comment in GEMV

* Fixed warning for unused variable

* vdotq_s32 -> ggml_vdotq_s32

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Clang-format issues

* Apply suggestions from code review

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* Removed unnecessary GGML_UNUSED

* Fixed guards in q4_k gemm and gemv (repack)

---------

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-11-24 13:08:11 +02:00
ixgbe 5f55c385cb ggml: add RISC-V cpu-feats (#17461)
* ggml: add RISC-V cpu-feats

Signed-off-by: Wang Yang <yangwang@iscas.ac.cn>

* fix comment[1]

---------

Signed-off-by: Wang Yang <yangwang@iscas.ac.cn>
2025-11-24 13:07:14 +02:00
william pan 4902eebe33 models : Added support for RND1 Diffusion Language Model (#17433)
* Converted RND1 model to GGUF weights

* RND1 llama.cpp support v1

* RND1 llama.cpp support v2 non causal bug

* RND1 llama.cpp support v3 doccumentation

* RND1 llama.cpp support v4 clean code

* linting issues

* RND1 pr fixes v1

* RND1 pr fixes v2

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Diffusion documentation edits

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-11-24 14:16:56 +08:00
Max Krasnyansky 923ae3c619 hexagon: add support for ROPE_NEOX (#17458) 2025-11-23 18:55:56 -08:00
150 changed files with 17422 additions and 8647 deletions
+1
View File
@@ -50,6 +50,7 @@ WORKDIR /app
RUN apt-get update \
&& apt-get install -y \
build-essential \
git \
python3 \
python3-pip \
+8 -23
View File
@@ -2,10 +2,8 @@
# multiplie collaborators per item can be specified
/.devops/*.Dockerfile @ngxson
/.github/actions/ @slaren @CISC
/.github/actions/ @CISC
/.github/workflows/ @CISC
/.github/workflows/release.yml @slaren
/.github/workflows/winget.yml @slaren
/ci/ @ggerganov
/cmake/ @ggerganov
/common/CMakeLists.txt @ggerganov
@@ -40,21 +38,14 @@
/examples/passkey/ @ggerganov
/examples/retrieval/ @ggerganov
/examples/save-load-state/ @ggerganov
/examples/simple-chat/ @slaren
/examples/simple/ @slaren
/examples/speculative-simple/ @ggerganov
/examples/speculative/ @ggerganov
/ggml/cmake/ @ggerganov
/ggml/include/ @ggerganov @slaren
/ggml/src/ggml-alloc.c @slaren
/ggml/src/ggml-backend* @slaren
/ggml/src/ggml-blas/ @slaren
/ggml/src/ggml-common.h @ggerganov @slaren
/ggml/src/ggml-cpu/ @ggerganov @slaren
/ggml/include/ @ggerganov
/ggml/src/ggml-common.h @ggerganov
/ggml/src/ggml-cpu/ @ggerganov
/ggml/src/ggml-cpu/spacemit/ @alex-spacemit
/ggml/src/ggml-cuda/common.cuh @slaren
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
@@ -62,19 +53,19 @@
/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
/ggml/src/ggml-hip/ @IMbackK
/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
/ggml/src/ggml-impl.h @ggerganov @slaren
/ggml/src/ggml-impl.h @ggerganov
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-threading.* @ggerganov @slaren
/ggml/src/ggml-threading.* @ggerganov
/ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml.c @ggerganov @slaren
/ggml/src/ggml.cpp @ggerganov @slaren
/ggml/src/ggml.c @ggerganov
/ggml/src/ggml.cpp @ggerganov
/ggml/src/gguf.cpp @JohannesGaessler @Green-Sky
/gguf-py/ @CISC
/media/ @ggerganov
@@ -86,15 +77,11 @@
/src/llama-arch.* @CISC
/src/llama-chat.* @ngxson
/src/llama-graph.* @CISC
/src/llama-model-loader.* @slaren
/src/llama-model.* @CISC
/src/llama-vocab.* @CISC
/src/models/ @CISC
/tests/ @ggerganov
/tests/test-backend-ops.cpp @slaren
/tests/test-thread-safety.cpp @slaren
/tools/batched-bench/ @ggerganov
/tools/llama-bench/ @slaren
/tools/main/ @ggerganov
/tools/mtmd/ @ngxson
/tools/perplexity/ @ggerganov
@@ -106,8 +93,6 @@
/tools/tokenize/ @ggerganov
/tools/tts/ @ggerganov
/vendor/ @ggerganov
/.clang-format @slaren
/.clang-tidy @slaren
/AUTHORS @ggerganov
/CMakeLists.txt @ggerganov
/CONTRIBUTING.md @ggerganov
+8 -8
View File
@@ -45,7 +45,7 @@ sd=`dirname $0`
cd $sd/../
SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
@@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
@@ -523,8 +523,8 @@ function gg_run_embd_bge_small {
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
set +e
}
@@ -564,7 +564,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output
# rerank score 0: 0.029
+26 -1
View File
@@ -694,6 +694,12 @@ static bool is_autoy(const std::string & value) {
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// default values specific to example
// note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up
if (ex == LLAMA_EXAMPLE_SERVER) {
params.use_jinja = true;
}
// load dynamic backends
ggml_backend_load_all();
@@ -1232,6 +1238,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
}
).set_sparam());
add_opt(common_arg(
@@ -1261,6 +1268,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
}
).set_sparam());
add_opt(common_arg(
@@ -1268,6 +1276,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
[](common_params & params, int value) {
params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
}
).set_sparam());
add_opt(common_arg(
@@ -1275,6 +1284,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) {
params.sampling.top_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
}
).set_sparam());
add_opt(common_arg(
@@ -1282,6 +1292,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) {
params.sampling.min_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
}
).set_sparam());
add_opt(common_arg(
@@ -1296,6 +1307,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sampling.xtc_probability = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
}
).set_sparam());
add_opt(common_arg(
@@ -1303,6 +1315,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sampling.xtc_threshold = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
}
).set_sparam());
add_opt(common_arg(
@@ -1321,6 +1334,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
}
).set_sparam());
add_opt(common_arg(
@@ -1328,6 +1342,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) {
params.sampling.penalty_repeat = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
}
).set_sparam());
add_opt(common_arg(
@@ -1425,6 +1440,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) {
params.sampling.mirostat = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
}
).set_sparam());
add_opt(common_arg(
@@ -1432,6 +1448,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_eta = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
}
).set_sparam());
add_opt(common_arg(
@@ -1439,6 +1456,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_tau = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
}
).set_sparam());
add_opt(common_arg(
@@ -2476,11 +2494,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--no-jinja"},
string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
[](common_params & params) {
params.use_jinja = false;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA"));
add_opt(common_arg(
{"--reasoning-format"}, "FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
+968
View File
@@ -13,6 +13,120 @@
using json = nlohmann::ordered_json;
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
const common_regex & prefix,
size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
if (auto res = builder.try_find_regex(prefix)) {
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
builder.add_content(builder.consume_rest());
}
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
arguments = (json{
{ "code", code + builder.healing_marker() }
})
.dump();
auto idx = arguments.find(builder.healing_marker());
if (idx != std::string::npos) {
arguments.resize(idx);
}
} else {
arguments = (json{
{ "code", code }
})
.dump();
}
return arguments;
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static void parse_json_tool_calls(
common_chat_msg_parser & builder,
const std::optional<common_regex> & block_open,
const std::optional<common_regex> & function_regex_start_only,
const std::optional<common_regex> & function_regex,
const common_regex & close_regex,
const std::optional<common_regex> & block_close,
bool allow_raw_python = false,
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
nullptr) {
auto parse_tool_calls = [&]() {
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
function_regex ? builder.try_find_regex(*function_regex, from) :
std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
name = get_function_name(*res);
} else {
GGML_ASSERT(res->groups.size() == 2);
name = builder.str(res->groups[1]);
}
first = false;
if (name.empty()) {
// get_function_name signalled us that we should skip this match and treat it as content.
from = res->groups[0].begin + 1;
continue;
}
from = std::string::npos;
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(close_regex);
}
continue;
}
if (maybe_raw_python) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
if (!builder.add_tool_call(name, "", arguments)) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
if (block_close) {
builder.consume_regex(*block_close);
}
builder.consume_spaces();
builder.add_content(builder.consume_rest());
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
}
} else {
parse_tool_calls();
}
}
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
@@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear();
}
/**
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
* to reduce incremental compile time for parser changes.
*/
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
static const std::vector<std::vector<std::string>> args_paths = {
{"tool_call", "arguments"},
{"tool_calls", "arguments"},
};
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
if (data.value.contains("tool_calls")) {
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool calls");
}
} else if (data.value.contains("tool_call")) {
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (data.value.contains("response")) {
const auto & response = data.value.at("response");
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
if (data.is_partial) {
throw common_chat_msg_partial_exception("incomplete response");
}
} else {
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
}
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("[THINK]", "[/THINK]");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
common_healing_marker healing_marker;
json args = json::object();
while (true) {
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
auto arg_name = builder.str(arg_res->groups[1]);
auto partial = builder.consume_json();
args[arg_name] = partial.json;
healing_marker.marker = partial.healing_marker.marker;
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
builder.consume_spaces();
if (!builder.try_consume_literal(",")) {
break;
}
} else {
break;
}
}
builder.consume_literal(")");
builder.consume_spaces();
auto arguments = args.dump();
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
static const common_regex function_regex("(?:<tool▁call▁begin>)?function<tool▁sep>([^\n]+)\n```json\n");
static const common_regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<tool▁sep>)");
static const common_regex close_regex("(?:[\\s]*)?<tool▁call▁end>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<minimax:tool_call>",
/* form.tool_start = */ "<invoke name=\"",
/* form.tool_sep = */ "\">",
/* form.key_start = */ "<parameter name=\"",
/* form.key_val_sep = */ "\">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</invoke>",
/* form.scope_end = */ "</minimax:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<|tool_calls_section_begin|>";
form.tool_start = "<|tool_call_begin|>";
form.tool_sep = "<|tool_call_argument_begin|>{";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}<|tool_call_end|>";
form.scope_end = "<|tool_calls_section_end|>";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_calls>[";
form.tool_start = "{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}, ";
form.scope_end = "]</tool_calls>";
form.raw_argval = false;
form.last_val_end = "";
form.last_tool_end = "}";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "";
form.tool_start = "<tool_call>\n{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}\n</tool_call>";
form.scope_end = "";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex analysis_regex("<\\|channel\\|>analysis");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
static const common_regex preamble_regex("<\\|channel\\|>commentary");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_literal("<|end|>")) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};
auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
};
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
auto match = regex.search(input, 0, true);
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
return match;
}
return std::nullopt;
};
do {
auto header_start_pos = builder.pos();
auto content_start = builder.try_find_literal("<|message|>");
if (!content_start) {
throw common_chat_msg_partial_exception("incomplete header");
}
auto header = content_start->prelude;
if (auto match = regex_match(tool_call1_regex, header)) {
auto group = match->groups[1];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (auto match = regex_match(tool_call2_regex, header)) {
auto group = match->groups[2];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (regex_match(analysis_regex, header)) {
builder.move_to(header_start_pos);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
continue;
}
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
builder.add_content(consume_end());
continue;
}
// Possibly a malformed message, attempt to recover by rolling
// back to pick up the next <|start|>
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
builder.move_to(header_start_pos);
} while (builder.try_find_regex(start_regex, std::string::npos, false));
auto remaining = builder.consume_rest();
if (!remaining.empty()) {
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
}
}
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "",
/* form.tool_start = */ "<tool_call>",
/* form.tool_sep = */ "",
/* form.key_start = */ "<arg_key>",
/* form.key_val_sep = */ "</arg_key>",
/* form.val_end = */ "</arg_value>",
/* form.tool_end = */ "</tool_call>",
/* form.scope_end = */ "",
/* form.key_val_sep2 = */ "<arg_value>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
static const common_regex close_regex(R"(\s*)");
parse_json_tool_calls(
builder,
std::nullopt,
function_regex_start_only,
function_regex,
close_regex,
std::nullopt,
/* allow_raw_python= */ true,
/* get_function_name= */ [&](const auto & res) -> std::string {
auto at_start = res.groups[0].begin == 0;
auto name = builder.str(res.groups[1]);
if (!name.empty() && name.back() == '{') {
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
builder.move_back(1);
}
auto idx = name.find_last_not_of("\n{");
name = name.substr(0, idx + 1);
if (at_start && name == "all") {
return "";
}
return name;
});
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
const auto & open_tag = res->groups[2];
std::string close_tag;
if (!res->groups[3].empty()) {
builder.move_to(res->groups[3].begin);
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
}
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
static const common_regex start_think_regex(regex_escape("<think>"));
static const common_regex end_think_regex(regex_escape("</think>"));
// Granite models output partial tokens such as "<" and "<think".
// By leveraging try_consume_regex()/try_find_regex() throwing
// common_chat_msg_partial_exception for these partial tokens,
// processing is interrupted and the tokens are not passed to add_content().
if (auto res = builder.try_consume_regex(start_think_regex)) {
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
builder.try_find_regex(end_think_regex, std::string::npos, false);
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
}
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags
static const common_regex start_response_regex(regex_escape("<response>"));
static const common_regex end_response_regex(regex_escape("</response>"));
// Granite models output partial tokens such as "<" and "<response".
// Same hack as reasoning parsing.
if (builder.try_consume_regex(start_response_regex)) {
builder.try_find_regex(end_response_regex);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
builder.consume_spaces();
if (!builder.try_consume_literal("<|tools_suffix|>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
for (const auto & value : tool_calls_data.json) {
if (value.is_object()) {
builder.add_tool_call_short_form(value);
}
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();
// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}
// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}
if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}
std::string function_name = tool_call.at("name");
std::string arguments = "{}";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}
// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}
// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<seed:tool_call>",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</function>",
/* form.scope_end = */ "</seed:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
case COMMON_CHAT_FORMAT_GENERIC:
common_chat_parse_generic(builder);
break;
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_MAGISTRAL:
common_chat_parse_magistral(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
common_chat_parse_functionary_v3_1_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
common_chat_parse_hermes_2_pro(builder);
break;
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
common_chat_parse_firefunction_v2(builder);
break;
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
case COMMON_CHAT_FORMAT_MINIMAX_M2:
common_chat_parse_minimax_m2(builder);
break;
case COMMON_CHAT_FORMAT_GLM_4_5:
common_chat_parse_glm_4_5(builder);
break;
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
builder.clear_tools();
builder.move_to(0);
common_chat_parse_content_only(builder);
}
}
auto msg = builder.result();
if (!is_partial) {
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}
-952
View File
File diff suppressed because it is too large Load Diff
+55
View File
@@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include <algorithm>
#include <cinttypes>
@@ -949,6 +950,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
// Model utils
//
static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
}
};
// Sampling sequence
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
char buf[512] = {0};
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
}
}
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
@@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
common_init_sampler_from_model(model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params);
+18
View File
@@ -140,6 +140,22 @@ struct common_grammar_trigger {
llama_token token = LLAMA_TOKEN_NULL;
};
enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -172,6 +188,8 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+2 -2
View File
@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
+79 -3
View File
@@ -565,7 +565,7 @@ class ModelBase:
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
):
data_qtype = gguf.GGMLQuantizationType.F32
@@ -4183,6 +4183,51 @@ class Qwen3MoeModel(Qwen2MoeModel):
super().set_vocab()
@ModelBase.register("Qwen3NextForCausalLM")
class Qwen3NextModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp"):
return [] # ignore MTP layers for now
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif "conv1d" in name:
data_torch = data_torch.squeeze()
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
data_torch = data_torch + 1
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1
def set_gguf_parameters(self):
super().set_gguf_parameters()
# RND1 specific parameters
# RND1 uses bidirectional attention
self.gguf_writer.add_causal_attention(False)
if (mask_token_id := self.hparams.get("mask_token_id")) is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
@@ -10046,6 +10091,25 @@ class LazyTorchTensor(gguf.LazyBase):
torch.uint8: np.uint8,
}
# only used when byteswapping data. Only correct size is needed
_dtype_byteswap_map: dict[torch.dtype, type] = {
torch.float64: np.float64,
torch.float32: np.float32,
torch.bfloat16: np.float16,
torch.float16: np.float16,
torch.int64: np.int64,
torch.uint64: np.uint64,
torch.int32: np.int32,
torch.uint32: np.uint32,
torch.int16: np.int16,
torch.uint16: np.uint16,
torch.int8: np.int8,
torch.uint8: np.uint8,
torch.bool: np.uint8,
torch.float8_e4m3fn: np.uint8,
torch.float8_e5m2: np.uint8,
}
# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
@@ -10089,8 +10153,14 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod
def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[tensor.dtype]
return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
numpy_dtype = cls._dtype_byteswap_map[dtype]
return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
dtype = cls._dtype_str_map[t.dtype]
shape = t.shape
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
@@ -10098,10 +10168,16 @@ class LazyTorchTensor(gguf.LazyBase):
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
if sys.byteorder == 'big':
# switch data back to big endian
tensor = tensor.view(dtype).byteswap(inplace=False)
return tensor
dtype = cls._dtype_str_map[remote_tensor.dtype]
numpy_dtype = cls._dtype_byteswap_map[dtype]
shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
return cast(torch.Tensor, lazy)
@classmethod
+1 -1
View File
@@ -242,7 +242,7 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
+13
View File
@@ -42,6 +42,9 @@ The following releases are verified and recommended:
## News
- 2025.11
- Support malloc memory on device more than 4GB.
- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent|
@@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
## Known Issues
@@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
You need to enable to support 4GB memory malloc by:
```
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
```
### **GitHub contribution**:
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
+1 -1
View File
@@ -3,7 +3,7 @@
The example demonstrates batched generation from a given prompt
```bash
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified
...
+48 -2
View File
@@ -6,8 +6,54 @@ More Info:
- https://github.com/ggml-org/llama.cpp/pull/14644
- https://github.com/ggml-org/llama.cpp/pull/14771
## Parameters
The diffusion CLI supports various parameters to control the generation process:
Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
### Core Diffusion Parameters
- `--diffusion-steps`: Number of diffusion steps (default: 256)
- `--diffusion-algorithm`: Algorithm for token selection
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
- `1`: ENTROPY_BASED - Entropy-based selection
- `2`: MARGIN_BASED - Margin-based selection
- `3`: RANDOM - Random selection
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
- More documentation here https://github.com/DreamLM/Dream
- `--diffusion-visual`: Enable live visualization during generation
Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
### Scheduling Parameters
Choose one of the following scheduling methods:
**Timestep-based scheduling:**
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
**Block-based scheduling:**
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
### Sampling Parameters
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
- `--top-k`: Top-k filtering for sampling
- `--top-p`: Top-p (nucleus) filtering for sampling
- `--seed`: Random seed for reproducibility
### Model Parameters
- `-m`: Path to the GGUF model file
- `-p`: Input prompt text
- `-ub`: Maximum sequence length (ubatch size)
- `-c`: Context size
- `-b`: Batch size
### Examples
#### Dream architechture:
```
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
```
#### LLaDA architechture:
```
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
```
#### RND1 architecture:
```
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
```
+4 -3
View File
@@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
params.embedding = true;
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
// in order to support any number of prompts
if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true;
params.n_parallel = n_seq_max;
}
// utilize the full context
@@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
params.n_ubatch = params.n_batch;
}
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
llama_backend_init();
llama_numa_init(params.numa);
+2 -2
View File
@@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
@@ -4,6 +4,11 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
if [ -z "$MODEL_TESTING_PROMPT"]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-logits -j8
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
@@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
# of using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
device = next(model.parameters()).device
if os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
+3
View File
@@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"
+6 -3
View File
@@ -6,7 +6,7 @@
# If you want more control, DPC++ Allows selecting a specific device through the
# following environment variable
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
@@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi
+2
View File
@@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0
+3 -1
View File
@@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99
+6 -4
View File
@@ -25,16 +25,17 @@ if(GIT_EXE)
)
endif()
# Build the version string with optional dirty flag
set(GGML_VERSION "${GGML_VERSION_BASE}")
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
set(GGML_VERSION "${GGML_VERSION}-dirty")
endif()
if(NOT GGML_BUILD_COMMIT)
set(GGML_BUILD_COMMIT "unknown")
endif()
# Build the commit string with optional dirty flag
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
endif()
include(CheckIncludeFileCXX)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -182,6 +183,7 @@ endif()
# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
# 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
+1 -1
View File
@@ -8,7 +8,7 @@ extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_MINOR_VERSION 5
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16
+14 -6
View File
@@ -530,6 +530,7 @@ extern "C" {
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_TOP_K,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,
GGML_OP_FILL,
@@ -2258,18 +2259,25 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);
// similar to ggml_top_k but implemented as `argsort` + `view`
GGML_API struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
// top k elements per row
// note: the resulting top k indices are in no particular order
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
#define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3 ]
+19
View File
@@ -221,6 +221,10 @@ if (GGML_BACKEND_DL)
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif()
if (GGML_SCHED_NO_REALLOC)
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()
add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
@@ -328,6 +332,14 @@ function(ggml_add_cpu_backend_variant tag_name)
set(GGML_INTERNAL_${feat} OFF)
endforeach()
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
foreach (feat RVV)
set(GGML_INTERNAL_${feat} OFF)
endforeach()
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
@@ -402,6 +414,13 @@ if (GGML_CPU_ALL_VARIANTS)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(riscv64_0)
ggml_add_cpu_backend_variant(riscv64_v RVV)
else()
message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif()
+8 -3
View File
@@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
if (realloc) {
#ifndef NDEBUG
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
{
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
if (cur_size > 0) {
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
__func__, ggml_backend_buft_name(galloc->bufts[i]),
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
}
}
#endif
ggml_vbuffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
if (galloc->buffers[i] == NULL) {
+9 -3
View File
@@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
// allocate graph
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
#ifdef GGML_SCHED_NO_REALLOC
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
#endif
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
// the re-allocation may cause the split inputs to be moved to a different address
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+390 -142
View File
@@ -42,6 +42,7 @@
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_fill_scalar.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_ger.h>
#include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_gt_scalar.h>
@@ -2206,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
}
/**
* @brief Initializes and caches sine/cosine positional encoding values
* (used in RoPE, Rotary Position Embedding) for attention layers.
* @brief Initializes and caches all intermediate tensors required for RoPE
* (Rotary Position Embedding), including support for Yarn, mRoPE,
* i-mRoPE, Neox repeat strategy, independent sectors, frequency factors
* and multi-section rotary groups.
*
* This function computes and caches the sin/cos values of
* θ = position * theta_scale for RoPE encoding. The cache is shared
* across attention layers, and only the first attention layer will
* trigger initialization. The cache includes repeated sin/cos values
* with different repeat methods depending on the @param is_neox flag.
* This function computes and caches the per-dimension θ coefficients used for
* Q/K rotary embedding. The cache is shared across layers, and recomputed only
* when any dependent parameter changes.
*
* Steps performed by this function:
* 1. Identify whether the target tensor belongs to Q/K in attention
* and restrict computation to the first layer only.
* 2. Initialize the theta scale array (arange → power → freq scaling).
* 3. Allocate sin/cos caches if the max prompt length increases.
* 4. Compute θ = position * theta_scale.
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* The function now supports:
* - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
* - Per-dimension independent sector exponent rules (indep_sects + sections[])
* - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
* - Frequency factor division (src2)
* - Neox / normal repeat expansion modes
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
* @param ctx CANN backend context, containing memory pool,
* cached buffers, and runtime stream.
* @param dst Destination ggml_tensor whose computation
* depends on RoPE (typically Qcur or Kcur).
* @param corr_dims [low, high] Yarn correction range.
* @param ext_factor Yarn extrapolation strength. 0 = disabled.
* @param theta_scale Base multiplier for per-dimension θ exponent.
* @param freq_scale Global frequency scaling factor.
* @param attn_factor Optional scaling applied to sin/cos (if needed).
* @param is_neox Whether to use Neox-style dimension interleave.
* @param sections 4-way sector sizes for independent-section RoPE
* and multi-section mRoPE (t/h/w/e).
* @param mrope_used Whether to enable multi-section rotary embedding.
* @param is_imrope Whether to apply interleaved mRoPE rules.
* @param indep_sects Whether each dimension runs independent exponent
* resets based on @p sections.
*/
static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ggml_tensor * dst,
float * corr_dims,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox) {
static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
ggml_tensor * dst,
float * corr_dims,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
int sections[4],
bool mrope_used,
bool is_imrope,
bool indep_sects) {
ggml_tensor * src0 = dst->src[0]; // input
ggml_tensor * src1 = dst->src[1]; // position
ggml_tensor * src2 = dst->src[2]; // freq_factors
if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor &&
ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale &&
ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t position_length = dst->ne[2];
// TODO: check theta_scale_length and position_length.
if (src2 == nullptr && ctx.rope_cache.cached &&
ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
is_neox, indep_sects, mrope_used, is_imrope, sections)) {
// use cache.
return;
}
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) };
// Step0: calculate tensor shape.
int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
size_t theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
theta_scale_length * sizeof(float) };
GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0];
int64_t position_ne[] = { 1, 1, position_length, 1 };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
int64_t position_ne[] = { 1, 1, position_length, 1 };
size_t position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t theta_nb[GGML_MAX_DIMS];
theta_nb[0] = sizeof(float);
int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
size_t cache_nb[GGML_MAX_DIMS];
cache_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
}
// theta_scale arange, [0,1,...,ne00/2 - 1]
// Step1: Compute the coefficient of theta. During the cache_init process, aside from
// (1) multiplying by the position,
// (2) dividing by freq_factors,
// (3) computing the sine and cosine,
// the other parameters used in the computation generally do not change in most scenarios.
// Therefore, we can first compute this part of the result and then cache it.
// Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor.
acl_tensor_ptr acl_theta_scale_tensor;
// cache theta scale
if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
// theta_scale and freq_scale should not change during the current token inference process,
// so we can directly use == here instead of comparing the absolute difference.
ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) {
ctx.rope_cache.theta_scale_length = theta_scale_length;
bool theta_scale_updated = false;
if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
ctx.rope_cache.indep_sects != indep_sects) {
theta_scale_updated = true;
if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
free(ctx.rope_cache.theta_scale_exp_host);
}
ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
if (!indep_sects) {
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
} else {
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
ctx.rope_cache.theta_scale_exp_host[0] = 1;
for (int i = 1; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
ctx.rope_cache.theta_scale_exp_host[i] = 1;
continue;
}
ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
}
}
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@@ -2285,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, 1);
}
float start = 0;
float step = 1;
float stop = theta_scale_length;
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements);
// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
bool yarn_ramp_tensor_updated = false;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 &&
// TODO: check more parameter.
(ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0) {
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
acl_scalar_ptr zero = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr one = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr denom_safe = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
acl_scalar_ptr ext_factor_sc = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(),
acl_yarn_ramp_tensor.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
// cache freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
}
// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
// cache freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
acl_scalar_ptr freq_scale_sc = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
}
// power
acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(),
acl_theta_scale_tensor.get());
if (ext_factor != 0) {
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
theta_scale_updated = true;
aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
} else {
// use cache
if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
theta_scale_updated = true;
aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
}
}
// Nothing changed, use cache.
if (!theta_scale_updated) {
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
// Step 1.4: prepare select index if mrope
acl_tensor_ptr position_select_index_tensor;
if (mrope_used) {
if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
if (ctx.rope_cache.position_select_index_host != nullptr) {
free(ctx.rope_cache.position_select_index_host);
}
ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
// t,h,w,e
for (int i = 0; i < theta_scale_length; i++) {
int sector = i % sect_dims;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
ctx.rope_cache.position_select_index_host[i] = 0;
} else {
ctx.rope_cache.position_select_index_host[i] = 3;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
ctx.rope_cache.position_select_index_host[i] = 1;
} else if (sector >= sec_w && sector < sec_e) {
ctx.rope_cache.position_select_index_host[i] = 2;
} else if (sector >= sec_e) {
ctx.rope_cache.position_select_index_host[i] = 3;
} else {
ctx.rope_cache.position_select_index_host[i] = 0;
}
}
}
if (ctx.rope_cache.position_select_index != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
}
position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
sizeof(int), theta_scale_ne, theta_scale_nb, 1);
}
// Step2: divide by freq_factors
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void * freq_fac_res_ptr = freq_fac_res_allocator.get();
@@ -2365,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
}
// Step3: prepare position_tensor
acl_tensor_ptr acl_position_tensor;
ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
if (mrope_used) {
// Step3.1: select current position;
// position :
// pos1: [[0, 1 ,2 ,3 ],
// pos2: [4, 5 ,6 ,7 ],
// pos3: [8, 9 ,10,11],
// pos4: [12,13,14,15] ]
//
// select index = [0, 1, 2, 2, 1, 0]
//
// selected_tensor:
// [[0, 1 ,2 ,3 ],
// [4, 5 ,6 ,7 ],
// [8, 9 ,10,11],
// [8, 9 ,10,11],
// [4, 5 ,6 ,7 ],
// [0, 1 ,2 ,3 ]]
//
// transpose, from [seq_len:dims] to [dims:seq_len]
// [0, 4, 8 ,8 ,4, 0],
// [1, 5, 9, 9, 5, 1],
// [2, 6, 10,10,6 ,2],
// [3, 7, 11,11,7 3 ]]
//
// multipy by theta_scale_tensor
// [theta_scale^0, theta_scale^1, ..., theta_scale ^ n]
int64_t mrope_position_ne[] = { position_length, 4 };
size_t mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
acl_tensor_ptr mrope_position =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
mrope_position_ne, mrope_position_nb, 2);
// selected position tensor's shape is a transpose of cache tensor.
int64_t selected_position_ne[] = { position_length, theta_scale_length };
size_t selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
void * mrope_position_buffer = mrope_position_acllocator.get();
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
acl_position_tensor.get());
// transpose
int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
size_t transposed_nb[GGML_MAX_DIMS];
transposed_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
}
std::swap(transposed_ne[0], transposed_ne[2]);
std::swap(transposed_nb[0], transposed_nb[2]);
acl_position_tensor =
ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
} else {
// auto bcast.
acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
position_ne, position_nb, GGML_MAX_DIMS);
}
// Step4: multiply by the position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// Step5: calculate sin cos.
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length;
@@ -2381,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
}
// position
acl_tensor_ptr acl_position_tensor =
ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne,
position_nb, GGML_MAX_DIMS);
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
void * theta_buffer = theta_allocator.get();
acl_tensor_ptr acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
void * sin_buffer = sin_allocator.get();
acl_tensor_ptr acl_sin_tensor =
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
void * cos_buffer = cos_allocator.get();
acl_tensor_ptr acl_cos_tensor =
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
// attn_factor
// Step 5: multiply by attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 };
int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -2429,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
// Step 6: repeat
if (is_neox) {
// [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
int64_t repeatsArray[] = { 1, 1, 1, 2 };
aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
@@ -2438,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
int64_t num_repeats = 2;
int64_t dim = 3;
int64_t output_size = theta_scale_length * num_repeats;
// [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
}
// Other layers use cache except first layer.
ctx.rope_cache.cached = true;
ctx.rope_cache.ext_factor = ext_factor;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
ctx.rope_cache.attn_factor = attn_factor;
ctx.rope_cache.is_neox = is_neox;
// Update cached value.
ctx.rope_cache.cached = true;
ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
indep_sects, mrope_used, is_imrope, sections);
}
#ifdef __cplusplus
@@ -2474,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
// const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -2482,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_TENSOR_UNARY_OP_LOCALS
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
@@ -2498,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
if (is_imrope || mrope_used) {
is_neox = true;
}
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox);
aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2657,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
return;
#endif
// ggml_mode = 0 --> aclnn_model = 1
int64_t acl_mode = mode == 0 ? 1 : mode;
int64_t acl_mode = is_neox ? 0 : 1;
switch (src0->type) {
case GGML_TYPE_F32:
@@ -3236,3 +3423,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
GGML_ABORT("Function is not implemented.");
}
}
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // weight
ggml_tensor * src1 = dst->src[1]; // input
GGML_TENSOR_BINARY_OP_LOCALS
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
const int64_t i12 = i2;
const int64_t i13 = i3;
acl_tensor_ptr accumulator =
ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
// The outer product needs to be accumulated in this dimension.
for (int64_t i1 = 0; i1 < ne11; i1++) {
acl_tensor_ptr acl_input = ggml_cann_create_tensor(
(char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src1->ne, src1->nb, 1);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
(char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, 1);
ggml_cann_pool_alloc output_allocator(ctx.pool());
void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
float alpha_value = 1.0f;
aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
}
}
}
}
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
const enum ggml_type type = src0->type;
switch (type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
ggml_cann_out_prod_fp(ctx, dst);
break;
default:
GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD");
break;
}
}
+20
View File
@@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
} while (0)
#endif // CANN_ACLNN_OPS
/**
* @brief Performs outer product operation on two ggml tensors using the CANN backend.
*
* @details This function computes the outer product of two input tensors (src0 and src1)
* and stores the result in the destination tensor. The outer product operation is defined as:
* dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
*
* The function supports multiple data types including F32, F16. For floating-point
* types, it uses batch matrix multiplication for efficient computation.
*
* The implementation handles 4D tensor broadcasting and batch processing automatically.
*
* @param ctx The CANN backend context for operation execution and memory management.
* @param dst The destination ggml_tensor where the outer product result will be stored.
* The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
*
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
*/
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+76 -14
View File
@@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if (theta_scale_cache != nullptr) {
if (theta_scale_cache) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if (sin_cache != nullptr) {
if (sin_cache) {
ACL_CHECK(aclrtFree(sin_cache));
}
if (cos_cache != nullptr) {
if (cos_cache) {
ACL_CHECK(aclrtFree(cos_cache));
}
if (position_select_index) {
ACL_CHECK(aclrtFree(position_select_index));
}
if (theta_scale_exp_host) {
free(theta_scale_exp_host);
}
if(position_select_index_host) {
free(position_select_index_host);
}
}
void * theta_scale_cache = nullptr;
int64_t theta_scale_length = 0;
bool equal(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
}
void set(int64_t theta_scale_length,
int64_t position_length,
float ext_factor,
float theta_scale,
float freq_scale,
float attn_factor,
bool is_neox,
bool indep_sects,
bool mrope_used,
bool is_imrope,
int sections[4]) {
this->theta_scale_length = theta_scale_length;
this->position_length = position_length;
this->ext_factor = ext_factor;
this->theta_scale = theta_scale;
this->freq_scale = freq_scale;
this->attn_factor = attn_factor;
this->is_neox = is_neox;
this->indep_sects = indep_sects;
this->mrope_used = mrope_used;
this->is_imrope = is_imrope;
this->sections[0] = sections[0];
this->sections[1] = sections[1];
this->sections[2] = sections[2];
this->sections[3] = sections[3];
}
// memory cache, prepare before inferencing.
void * theta_scale_cache = nullptr;
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;
int64_t position_length = 0;
void * sin_cache = nullptr;
void * cos_cache = nullptr;
// Properties to check before reusing the sincos cache
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
int64_t theta_scale_length = 0;
int64_t position_length = 0;
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool indep_sects = false;
bool mrope_used = false;
int sections[4] = { 0, 0, 0, 0 };
bool is_imrope = false;
};
struct ggml_cann_tensor_cache {
+13 -7
View File
@@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
break;
default:
return false;
}
@@ -2477,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return false;
}
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
if (op->src[0]->ne[0] > 896) {
return false;
}
@@ -2563,6 +2559,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_OUT_PROD:
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
}
case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255;
+29 -15
View File
@@ -224,7 +224,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
string(REPLACE ";" " " ARCH_FLAGS_STR "${ARCH_FLAGS}")
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS_STR}")
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
set(ARM_FEATURE "HAVE_${feature}")
check_cxx_source_compiles(
@@ -452,22 +453,35 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/spacemit/ime_kernels.h
)
endif()
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
if(NOT GGML_CPU_ALL_VARIANTS)
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
endif()
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
else()
# Begin with the lowest baseline
set(ARCH_DEFINITIONS "")
if (GGML_INTERNAL_RVV)
message(STATUS "RVV enabled")
list(APPEND ARCH_DEFINITIONS GGML_USE_RVV)
list(APPEND ARCH_FLAGS -march=rv64gc_v -mabi=lp64d)
endif()
ggml_add_cpu_backend_features(${GGML_CPU_NAME} riscv ${ARCH_DEFINITIONS})
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES
+22 -2
View File
@@ -33,10 +33,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -44,27 +46,30 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__POWERPC__) || defined(__powerpc__)
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
@@ -76,10 +81,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -87,6 +94,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -101,10 +109,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -112,6 +122,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -134,15 +145,18 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -163,10 +177,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -174,6 +190,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -196,10 +213,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -207,6 +226,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+721
View File
@@ -24,6 +24,29 @@
#define UNUSED GGML_UNUSED
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
int16x8_t * out_mins,
int8_t * out_scales) {
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
constexpr uint8_t scales_size = 12;
uint32_t sm[3];
memcpy(sm, scales_in, scales_size);
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
*out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
uint32_t scales_u32[2];
scales_u32[0] = sm[0] & kmask1;
scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
memcpy(out_scales, scales_u32, 8);
}
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
@@ -474,6 +497,295 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[col_groups];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < col_groups; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
int32x4_t acc_lo[col_groups];
int32x4_t acc_hi[col_groups];
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_groups; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q4sb_mins[2];
int16x8_t q4sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
int8x16_t q8_qs[64 / 16];
for (int i = 0; i < 64 / 16; i++) {
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
}
for (int c = 0; c < col_groups; c++) {
uint8x16_t q4_cols[8];
for (int i = 0; i < 8; i++) {
q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
}
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
}
// Scales
// row c0123 blk0 and blk1
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
// row c4567 blk0 and blk1
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
// Bias Correction
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
} // for sb
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[ncols_interleaved / 4];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < ncols_interleaved / 4; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
// 2 sb each iteration
int32x4_t acc_lo[col_pairs];
int32x4_t acc_hi[col_pairs];
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_pairs; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
int16x8_t q4sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
// Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
// but still need the qs to use the low and hi bits from q4
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
int8x16_t q8_qs[8];
for (int i = 0; i < 8; i++) {
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
}
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < col_pairs; cp++) {
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]); // 0 .. 7
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]); // 8 ..15
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]); // 16..23
acc_lo[cp] =
ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]); // 24..31
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]); // 32..39
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]); // 40..47
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]); // 48..55
acc_hi[cp] =
ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]); // 56..63
}
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
// p = 0 -> 0123 p2 -> 4567
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
// 0123 or 4567
float32x4_t sumf_0 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
float32x4_t sumf_1 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
}
// Multiply Acc bsum + mins
// Each pair of subblocks share the same bsums
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
// cols 0-3 bias
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
// cols 4-7 bias
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
} // for sb
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -1889,3 +2201,412 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < acc_size; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// d4 0 1 2 3, 4 5 6 7
float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
// d8 0 1 2 3
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
// mins
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
// Precomputation of scales and mins
float32x4_t sbd_scale_0123[q8_k_blocklen];
float32x4_t sbd_scale_4567[q8_k_blocklen];
float32x4_t sbd_min_0123[q8_k_blocklen];
float32x4_t sbd_min_4567[q8_k_blocklen];
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
const int16x8_t bsums[q8_k_blocklen] = {
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[QK_K / 64][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
int32x4_t bias_acc[acc_size];
for (int i = 0; i < acc_size; i++) {
bias_acc[i] = vdupq_n_s32(0);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Int accumulators for qs vecdot (4 row x 2 col quartets)
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q4sb_scales[2];
int16x8_t q4sb_mins[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
for (int k = 0; k < reads_per_sb; k++) {
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
// 0..3 & 32..35
const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
}
// Scale and bias application
// acc is stored interleaved to match output layout
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
for (int row = 0; row < q8_k_blocklen; row++) {
// Bias correction
// row c0123 blk0 and blk1
const float32x4_t sumf_0123 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
// row c4567 blk0 and blk1
const float32x4_t sumf_4567 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
// Bias
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
// row c0123 blk0 and blk1
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
// row c4567 blk0 and blk1
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
}
} // for sb
for (int row = 0; row < q8_k_blocklen; row++) {
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
acc_f32[2 * row + 1] =
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
}
} // for b
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
float32x4_t acc_f32[blocklen];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < blocklen; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// bsums pairs belongs to the same q8_k subblock
const int16x8_t bsums[4]{
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[4][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
for (int i = 0; i < 8; i++) {
acc[i] = vdupq_n_s32(0);
bias_acc[i] = vdupq_n_s32(0);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int8_t q4sb_scales[2][8];
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
for (int i = 0; i < 2; i++) {
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
}
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
int8x16_t q8_qs_01[8];
int8x16_t q8_qs_23[8];
// Load 32-byte per row pair, 1 subblock each time
for (int i = 0; i < 8; i++) {
const int offset = i * 32; // 16 for row 01, 16 for row 23
q8_qs_01[i] = vld1q_s8(q8_base + offset);
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
}
const int8x16_t q8s[2][8] = {
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
};
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
for (int i = 0; i < 4; i++) {
sb_acc[i] = vdupq_n_s32(0);
}
uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
const int8x16_t q4_nibbles[2][4] = {
{
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
},
{
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
}
};
// Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
// for each of the internal 32 qs subblock (blk)
for (int rp = 0; rp < 2; rp++) {
for (int blk = 0; blk < 2; blk++) {
const int8x16_t * q8 = &q8s[rp][4 * blk];
const int8x16_t * q4 = q4_nibbles[blk];
int32x4_t acc = sb_acc[2 * rp + blk];
// mul add for each qs in the same subblock
for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
}
sb_acc[2 * rp + blk] = acc;
}
}
// Scales[i] corresponds to column i
const int scale_offset = cp * 2;
for (int blk = 0; blk < 2; blk++) {
const int32x4_t block_scale = {
(int32_t) q4sb_scales[blk][scale_offset],
(int32_t) q4sb_scales[blk][scale_offset],
(int32_t) q4sb_scales[blk][scale_offset + 1],
(int32_t) q4sb_scales[blk][scale_offset + 1],
};
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
}
}
// Multiply Acc bsum + mins
for (int q8_row = 0; q8_row < 4; q8_row++) {
// Each pair of subblocks share the same bsums
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
bias_acc[2 * q8_row] =
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[2 * q8_row] =
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
bias_acc[2 * q8_row + 1] =
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[2 * q8_row + 1] =
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
}
} // for sb
// Reorder of i8mm output with bias and output layout
for (int i = 0; i < 8; i++) {
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
}
int32x4_t reorder_acc[8] = {
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
};
for (int i = 0; i < q8_k_blocklen; i++) {
for (int j = 0; j < 2; j++) {
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
const float32x4_t scale = vmulq_f32(q4_d, q8_d);
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
acc_f32[2 * i + j] =
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
}
}
} // for b
// With the previous reorder, the tile is already in the correct memory layout.
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
@@ -0,0 +1,38 @@
#include "ggml-backend-impl.h"
#if defined(__riscv) && __riscv_xlen == 64
#include <asm/hwprobe.h>
#include <asm/unistd.h>
#include <unistd.h>
struct riscv64_features {
bool has_rvv = false;
riscv64_features() {
struct riscv_hwprobe probe;
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
probe.value = 0;
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
if (0 == ret) {
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
}
}
};
static int ggml_backend_cpu_riscv64_score() {
int score = 1;
riscv64_features rf;
#ifdef GGML_USE_RVV
if (!rf.has_rvv) { return 0; }
score += 1 << 1;
#endif
return score;
}
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)
#endif // __riscv && __riscv_xlen == 64
+9
View File
@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argsort(params, tensor);
} break;
case GGML_OP_TOP_K:
{
ggml_compute_forward_top_k(params, tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
} break;
case GGML_OP_TOP_K:
{
cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK
+71 -4
View File
@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort
template<enum ggml_sort_order order>
struct argsort_cmp {
struct cmp_argsort {
const float * data;
bool operator()(int32_t a, int32_t b) const {
if constexpr (order == GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
switch (order) {
case GGML_SORT_ORDER_ASC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
break;
case GGML_SORT_ORDER_DESC:
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
break;
default:
@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
}
}
// ggml_compute_forward_top_k
struct cmp_top_k {
const float * data;
bool operator()(int32_t a, int32_t b) const {
return data[a] > data[b];
}
};
static void ggml_compute_forward_top_k_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
const int top_k = ne0;
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int64_t i = ith; i < nr; i += nth) {
const float * src_data = (float *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne00; j++) {
tmp[j] = j;
}
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
std::copy(tmp, tmp + top_k, dst_data);
// emphasize that the order is not important
if (top_k > 1) {
std::swap(dst_data[0], dst_data[1]);
}
}
}
void ggml_compute_forward_top_k(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_top_k_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
@@ -9700,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
}
const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
}
}
+1
View File
@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+243 -4
View File
@@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
}
}
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
// scalar
const int blck_size_interleave = 4;
float srcv[4][QK_K];
float iscale[4];
for (int i = 0; i < nb; i++) {
for (int row_iter = 0; row_iter < 4; row_iter++) {
float amax = 0.0f; // absolute max
float max = 0;
for (int j = 0; j < QK_K; j++) {
srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
// Update the maximum value of the corresponding super block
if(amax < fabsf(srcv[row_iter][j])) {
amax = fabsf(srcv[row_iter][j]);
max = srcv[row_iter][j];
}
}
iscale[row_iter] = amax ? -127.f/max : 0;
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
}
for (int j = 0; j < QK_K / 4; j++) {
y[i].bsums[j] = 0;
}
// Quants values are interleaved in sequence of four bytes from corresponding super blocks
// Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
for (int j = 0; j < QK_K * 4; j++) {
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (j % blck_size_interleave);
int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
float x0 = srcv[src_id][src_offset] * iscale[src_id];
y[i].qs[j] = nearest_int(x0);
y[i].bsums[index] += y[i].qs[j];
}
}
}
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
@@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
}
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
}
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
@@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 4;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
float sum_minf[8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
}
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 4;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert (n % qk == 0);
assert (nr % 4 == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
float sumf[4][8];
float sum_minf[4][8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
}
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for(int m = 0; m < 4; m++) {
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
for(int j = 0; j < ncols_interleaved; j++) {
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
}
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
GGML_ASSERT(interleave_block == 8);
GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
constexpr int nrows_interleaved = 8;
block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@@ -1468,6 +1681,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
@@ -1501,6 +1718,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@@ -1529,6 +1750,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
@@ -1731,12 +1956,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
}
if (nth == 1 || nchunk0 < nth || disable_chunking) {
int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
// Only increase nchunk0 to nth if it won't make chunks too small
if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
nchunk0 = nth;
dr0 = (nr0 + nchunk0 - 1) / nchunk0;
}
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
// This prevents creating too many tiny chunks that could overlap after alignment
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
@@ -1930,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
// instance for Q4_K
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for Q2
@@ -1961,6 +2190,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q4_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 8 == 0) {
return &q4_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 8 == 0) {
return &q4_K_8x4_q8_K;
}
}
} else if (cur->type == GGML_TYPE_Q2_K) {
if (ggml_cpu_has_avx512()) {
if (cur->ne[1] % 8 == 0) {
+6
View File
@@ -80,10 +80,12 @@ extern "C" {
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+90 -91
View File
@@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
}
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np= (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const int np = n;
_Float16 hv = (_Float16)v;
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e16m8(n - i);
vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
__riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
}
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
}
#else
// scalar
for (int i = 0; i < n; ++i) {
const int np = 0;
#endif
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
}
// xs and vs are byte strides of x and v
+18 -10
View File
@@ -84,12 +84,12 @@
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB
@@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) {
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
@@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) {
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fast_fp16_available(const int cc) {
return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) {
}
static bool bf16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fp32_mma_hardware_available(const int cc) {
@@ -558,8 +562,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
acc += v.y*u.y;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
#define V_DOT2_F32_F16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#ifdef V_DOT2_F32_F16_AVAILABLE
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
@@ -571,7 +579,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
#endif // V_DOT2_F32_F16_AVAILABLE
}
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
+4 -1
View File
@@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
}
}
}
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
nb12, nb13);
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
@@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else if (nb00 > nb02) {
} else {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;
+2 -2
View File
@@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#else
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
+1 -1
View File
@@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;
+18 -18
View File
@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
constexpr int ne_KQ = ncols*D;
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#else
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
float KQ_max[ncols];
float KQ_sum[ncols];
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
}
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
if constexpr (Q_q8_1) {
@@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
tmp_q_i32[i] = 0;
}
}
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
__syncthreads();
} else {
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
Q_reg[j][k].y *= scale;
}
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
KQ_reg[j] = sum;
}
}
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
#ifndef GGML_USE_HIP
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
}
}
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
KQ_max[j_VKQ] = kqmax_new;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
KQ_sum[j_VKQ] *= kqmax_scale;
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
+22 -12
View File
@@ -53,6 +53,7 @@
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml.h"
#include <algorithm>
@@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_OPT_STEP_SGD:
ggml_cuda_opt_step_sgd(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
default:
return false;
}
@@ -3046,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() &&
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
@@ -3056,8 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@@ -3065,7 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@@ -3081,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
@@ -3095,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
@@ -3107,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * rope = cgraph->nodes[node_idx];
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
@@ -3837,7 +3845,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
// Check if UMA is explicitly enabled via environment variable
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
bool is_uma = prop.unifiedAddressing > 0 || uma_env;
bool is_uma = prop.integrated > 0 || uma_env;
if (is_uma) {
// For UMA systems (like DGX Spark), use system memory info
@@ -4255,6 +4263,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
default:
return false;
}
+112 -32
View File
@@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
static constexpr int I = I_;
static constexpr int J = J_;
#if defined(GGML_USE_HIP)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#else
#if defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
T x[ne] = {0};
@@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(RDNA4)
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
return -1;
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#endif
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -437,7 +437,29 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{
NO_DEVICE_CODE;
}
} else {
NO_DEVICE_CODE;
}
#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@@ -772,6 +794,36 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
#if defined(RDNA4)
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
true
);
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[1],
true,
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -798,6 +850,7 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -836,10 +889,37 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}
static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
true,
b_vec[0],
acc[0],
false
);
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif
}
}
+1 -1
View File
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (src1_ncols > 16) {
return false;
}
}
+7 -1
View File
@@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
}
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
+301 -138
View File
File diff suppressed because it is too large Load Diff
+203
View File
@@ -0,0 +1,203 @@
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"
#define MAX_N_FAST 64
#define MAX_K_FAST 32
// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When ncols_template == 0 the bounds for the loops in this function are not
// known and can't be unrolled. As we want to keep pragma unroll for all other
// cases we supress the clang transformation warning here.
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n_arg,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int col_idx = threadIdx.y;
if (col_idx >= k) {
return;
}
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
__syncthreads();
#pragma unroll
for (int row = 0; row < n; ++row) {
float sum = 0.0f;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
}
}
}
#ifdef __clang__
# pragma clang diagnostic pop
#endif // __clang__
static void solve_tri_f32_cuda(const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb12,
size_t nb13,
size_t nb2,
size_t nb3,
cudaStream_t stream) {
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
if (n == 64) {
switch (k) {
case 32:
solve_tri_f32_fast<64, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
ggml_is_contiguous(src0);
ggml_is_contiguous(src1);
const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
GGML_ASSERT(n <= 64);
GGML_ASSERT(k <= 32);
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+1 -1
View File
@@ -2229,7 +2229,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
int mode = op_params[2];
if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
return false;
}
if (mode & 1) {
+80 -7
View File
@@ -24,6 +24,10 @@
#include "hvx-utils.h"
#include "ops-utils.h"
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h
#define HTP_ROPE_TYPE_NORMAL 0
#define HTP_ROPE_TYPE_NEOX 2
#define htp_rope_preamble \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
@@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
}
static void hvx_calc_rope_neox_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
//const float x0 = src[0];
//const float x1 = src[num_elems/2];
//dst[0] = x0*cos_theta - x1*sin_theta;
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
//src += 1;
//dst += 1;
// }
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
uint8_t * restrict dst_curr = (uint8_t *) dst;
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
int half_size = (sizeof(float) * (num_elems / 2));
for (int i = 0; i < step_of_1; i++) {
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
src0_curr += VLEN;
theta_curr += 2 * VLEN;
dst_curr += VLEN;
}
}
static void hvx_calc_rope_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
@@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
htp_rope_preamble;
const int32_t * pos = (const int32_t *) src1->data;
@@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
float * dst_data_loc = dst_data;
if (1 == opt_path) {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
if (is_neox) {
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
} else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
}
} else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0];
const float sin_theta = wp0[i0 + 1];
const float x0 = src_loc[0];
const float x1 = src_loc[1];
if (is_neox) {
const float x0 = src_loc[0];
const float x1 = src_loc[rope_ctx->n_dims/2];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
src_loc += 2;
dst_data_loc += 2;
src_loc += 1;
dst_data_loc += 1;
} else {
const float x0 = src_loc[0];
const float x1 = src_loc[1];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
src_loc += 2;
dst_data_loc += 2;
}
}
}
+58
View File
@@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
return res;
}
// note: reuse the argsort kernel for top_k
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
// note: the top_k kernel is always descending order
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_TOP_K);
char base[256];
char name[256];
ggml_sort_order order = GGML_SORT_ORDER_DESC;
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
+2
View File
@@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+1
View File
@@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:
+14 -4
View File
@@ -832,14 +832,19 @@ typedef struct {
} ggml_metal_kargs_leaky_relu;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
} ggml_metal_kargs_argsort;
typedef struct {
@@ -851,6 +856,11 @@ typedef struct {
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
int32_t top_k;
int32_t len;
} ggml_metal_kargs_argsort_merge;
+143 -17
View File
@@ -406,6 +406,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argsort(ctx, idx);
} break;
case GGML_OP_TOP_K:
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@@ -3678,14 +3682,19 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
}
ggml_metal_kargs_argsort args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nth,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -3705,15 +3714,20 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
ggml_metal_kargs_argsort_merge args_merge = {
.ne00 = ne00,
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb00 = nb00,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.len = len,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ ne00,
/*.len =*/ len,
};
// merges per row
@@ -3737,6 +3751,118 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
// bitonic sort requires the number of elements to be power of 2
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
// blocks per row
const int npr = (ne00 + nth - 1)/nth;
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
const int top_k = ne0;
ggml_metal_kargs_argsort args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
};
if (npr > 1) {
args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
int len = args.top_k;
while (len < args.ne0) {
ggml_metal_op_concurrency_reset(ctx);
// merges per row
const int nm = (args.ne0 + 2*len - 1) / (2*len);
const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
ggml_metal_kargs_argsort_merge args_merge = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ args.ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
/*.len =*/ len,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
len <<= 1;
}
return 1;
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
+1
View File
@@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
+4
View File
@@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
{
res *= 2;
} break;
case GGML_OP_TOP_K:
{
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
} break;
default:
break;
}
+22 -15
View File
@@ -4670,11 +4670,12 @@ kernel void kernel_argsort_f32_i32(
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
@@ -4710,9 +4711,11 @@ kernel void kernel_argsort_f32_i32(
}
}
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i00 + col < args.ne00) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
@@ -4747,22 +4750,22 @@ kernel void kernel_argsort_merge_f32_i32(
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne00
+ i02*args.ne00*args.ne01
+ i03*args.ne00*args.ne01*args.ne02;
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.ne00
+ i02*args.ne00*args.ne01
+ i03*args.ne00*args.ne01*args.ne02;
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
@@ -4776,7 +4779,11 @@ kernel void kernel_argsort_merge_f32_i32(
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total);
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;
+4
View File
@@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS
group_norm
im2col_f32
im2col_f16
mean
mul_mat_Ab_Bi_8x4
mul_mv_f16_f16
mul_mv_f16_f32_1row
@@ -109,6 +110,9 @@ set(GGML_OPENCL_KERNELS
softmax_4_f16
softmax_f32
softmax_f16
sqr
sqrt
ssm_conv
sub
sum_rows
transpose
+331
View File
@@ -449,6 +449,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
cl_kernel kernel_add_id;
cl_kernel kernel_scale;
cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
cl_kernel kernel_mean_f32;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@@ -509,6 +512,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_conv_2d_f16;
cl_kernel kernel_conv_2d_f32;
cl_kernel kernel_conv_2d_f16_f32;
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@@ -1552,6 +1556,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// sqr
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "sqr.cl.h"
};
#else
const std::string kernel_src = read_file("sqr.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err));
CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err));
CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// sqrt
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "sqrt.cl.h"
};
#else
const std::string kernel_src = read_file("sqrt.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err));
CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err));
CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mean
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mean.cl.h"
};
#else
const std::string kernel_src = read_file("mean.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// sub
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1825,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
}
}
// ssm_conv
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "ssm_conv.cl.h"
};
#else
const std::string kernel_src = read_file("ssm_conv.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err));
CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_id_q4_0_f32_8x_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2959,6 +3041,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
ggml_is_contiguous(op->src[0]);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
@@ -3007,6 +3093,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
case GGML_OP_SSM_CONV:
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
case GGML_OP_CONCAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -3075,6 +3163,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
}
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
@@ -5193,6 +5282,224 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
}
}
static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
// Currently assumes src0 is contiguous
int n = ggml_nelements(dst);
if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqr_cont_f32_4;
} else {
kernel = backend_ctx->kernel_sqr_cont_f16_4;
}
n /= 4;
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqr_cont_f32;
} else {
kernel = backend_ctx->kernel_sqr_cont_f16;
}
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
// Currently assumes src0 is contiguous
int n = ggml_nelements(dst);
if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqrt_cont_f32_4;
} else {
kernel = backend_ctx->kernel_sqrt_cont_f16_4;
}
n /= 4;
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqrt_cont_f32;
} else {
kernel = backend_ctx->kernel_sqrt_cont_f16;
}
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
cl_kernel kernel = backend_ctx->kernel_mean_f32;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
int ne01 = src0->ne[1];
cl_ulong nb00 = src0->nb[0];
cl_ulong nb01 = src0->nb[1];
cl_ulong nb02 = src0->nb[2];
int ne10 = src1->ne[0];
cl_ulong nb11 = src1->nb[1];
int ne1 = dst->ne[1];
int ne2 = dst->ne[2];
cl_ulong nb0 = dst->nb[0];
cl_ulong nb1 = dst->nb[1];
cl_ulong nb2 = dst->nb[2];
cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32;
if (ne10 % 4 == 0) {
kernel = backend_ctx->kernel_ssm_conv_f32_f32_4;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -9091,6 +9398,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_sub;
break;
case GGML_OP_SQR:
if (!any_on_device) {
return false;
}
func = ggml_cl_sqr;
break;
case GGML_OP_SQRT:
if (!any_on_device) {
return false;
}
func = ggml_cl_sqrt;
break;
case GGML_OP_MEAN:
if (!any_on_device) {
return false;
}
func = ggml_cl_mean;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
@@ -9192,6 +9517,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_conv_2d;
break;
case GGML_OP_SSM_CONV:
if (!any_on_device) {
return false;
}
func = ggml_cl_ssm_conv;
break;
case GGML_OP_CONCAT:
if (!any_on_device) {
return false;
+39
View File
@@ -0,0 +1,39 @@
kernel void kernel_mean_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
}
dst_row[0] = row_sum / ne00;
}
+53
View File
@@ -0,0 +1,53 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void kernel_sqr_cont_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
kernel void kernel_sqr_cont_f32_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
kernel void kernel_sqr_cont_f16(
global half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
kernel void kernel_sqr_cont_f16_4(
global half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
+53
View File
@@ -0,0 +1,53 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void kernel_sqrt_cont_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = sqrt(src0[gid]);
}
kernel void kernel_sqrt_cont_f32_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = sqrt(src0[gid]);
}
kernel void kernel_sqrt_cont_f16(
global half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = convert_half(sqrt(convert_float(src0[gid])));
}
kernel void kernel_sqrt_cont_f16_4(
global half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = convert_half4(sqrt(convert_float4(src0[gid])));
}
+77
View File
@@ -0,0 +1,77 @@
kernel void kernel_ssm_conv_f32_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb11,
ulong nb0,
ulong nb1,
ulong nb2
){
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int ir = get_global_id(0);
int i2 = get_global_id(1);
int i3 = get_global_id(2);
int nc = ne10;
global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
global float * c = (global float *) (src1 + ir*nb11);
global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
d[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_4(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb11,
ulong nb0,
ulong nb1,
ulong nb2
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int ir = get_global_id(0);
int i2 = get_global_id(1);
int i3 = get_global_id(2);
int nc = ne10;
global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
global float4 * c = (global float4 *) (src1 + ir*nb11);
global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);
float sumf = 0.0f;
for (int i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
d[0] = sumf;
}
+91 -20
View File
@@ -106,6 +106,7 @@ enum rpc_cmd {
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_COUNT,
};
@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
uint8_t result;
};
struct rpc_msg_graph_compute_rsp {
uint8_t result;
};
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};
struct rpc_msg_graph_recompute_req {
uint32_t device;
};
#pragma pack(pop)
// RPC data structures
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
size_t max_size;
};
struct graph_cache {
bool is_cached(const ggml_cgraph * cgraph) {
if ((int)last_graph.size() != cgraph->n_nodes) {
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
return false;
}
}
return true;
}
void add(const ggml_cgraph * cgraph) {
last_graph.resize(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) {
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
}
}
std::vector<ggml_tensor> last_graph;
};
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
graph_cache gc;
};
struct ggml_backend_rpc_buffer_context {
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.result;
GGML_ASSERT(cgraph->n_nodes > 0);
bool reuse = rpc_ctx->gc.is_cached(cgraph);
if (reuse) {
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
RPC_STATUS_ASSERT(status);
} else {
rpc_ctx->gc.add(cgraph);
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
RPC_STATUS_ASSERT(status);
}
return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_rpc_interface = {
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name
/* .name = */ dev_name,
/* .gc = */ {},
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
class rpc_server {
public:
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
: backends(std::move(backends)), cache_dir(cache_dir) {
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
: backends(std::move(all_backends)), cache_dir(cache_dir) {
stored_graphs.resize(backends.size());
}
~rpc_server();
@@ -936,11 +976,17 @@ public:
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input);
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
struct stored_graph {
ggml_context_ptr ctx_ptr;
ggml_cgraph * graph;
};
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ private:
std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
// store the last computed graph for each backend
std::vector<stored_graph> stored_graphs;
};
void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
return result;
}
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
}
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
stored_graphs[device].graph = graph;
return true;
}
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
uint32_t device = request.device;
if (device >= backends.size()) {
return false;
}
if (stored_graphs[device].graph == nullptr) {
return false;
}
ggml_cgraph * graph = stored_graphs[device].graph;
LOG_DBG("[%s] device: %u\n", __func__, device);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
return true;
}
@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_graph_compute_rsp response;
if (!server.graph_compute(input, response)) {
if (!server.graph_compute(input)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
break;
}
case RPC_CMD_GRAPH_RECOMPUTE: {
rpc_msg_graph_recompute_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
if (!server.graph_recompute(request)) {
return;
}
break;
+6 -2
View File
@@ -91,7 +91,10 @@ if (GGML_SYCL_F16)
add_compile_definitions(GGML_SYCL_F16)
endif()
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
if (GGML_SYCL_TARGET STREQUAL "INTEL")
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# INFO: Allowed Sub_group_sizes are not consistent through all
@@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
else()
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
# default for other target
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
endif()
if (GGML_SYCL_GRAPH)
+26
View File
@@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
return dpct::pow(base, exph);
}
static const sycl::uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
return sycl::uint3(mp, L, d);
}
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
return (hi + n) >> fastdiv_values.y();
}
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z();
return sycl::uint2(div_val, mod_val);
}
#endif // GGML_SYCL_COMMON_HPP
-3
View File
@@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
GGML_TENSOR_BINARY_OP_LOCALS01;
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+79 -51
View File
@@ -1,72 +1,100 @@
#include "pad_reflect_1d.hpp"
void pad_reflect_1d_f32(const float* src,float* dst,
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const sycl::nd_item<3> &item_ct1){
static void pad_reflect_1d_kernel_f32(
const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
const int64_t ne03, const int64_t nb00, const int64_t nb01,
const int64_t nb02, const int64_t nb03, const int64_t nb0,
const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
const int p1, sycl::nd_item<3> item_ct1) {
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
const int i1 = item_ct1.get_group(1);
const int g2 = item_ct1.get_group(2);
const int i2 = g2 % ne02;
const int i3 = g2 / ne02;
const int64_t i3 = item_ct1.get_group(0);
const int64_t i2 = item_ct1.get_group(1);
if (i0 >= p0 + ne0 + p1) return;
const sycl::uint2 div_mod_packed =
fast_div_modulo(item_ct1.get_group(2), ne01);
const int64_t tile1 = div_mod_packed.y();
const int64_t tile0 = div_mod_packed.x();
const int64_t i1 = tile1;
const int64_t i0 =
item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
int t = i0 - p0;
int period = 2 * ne0 -2;
int m = t % period;
m += (m < 0) * period;
int center = ne0 -1;
int srci0 = center - abs(center - m);
if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
return;
}
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
dst[offest_dst] = src[offest_src];
const char *src0_ptr =
(const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
int64_t src_idx;
if (rel_i0 < 0) {
// Left padding - reflect
src_idx = -rel_i0;
} else if (rel_i0 < ne00) {
// Middle - copy
src_idx = rel_i0;
} else {
// Right padding - reflect
src_idx = 2 * ne00 - 2 - rel_i0;
}
const float value = *(const float *)(src0_ptr + src_idx * nb00);
*(float *)(dst_ptr + i0 * nb0) = value;
GGML_UNUSED(p1);
}
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
ggml_tensor *dst) {
const ggml_tensor * src0 = dst->src[0];
queue_ptr stream = ctx.stream();
const ggml_tensor *src0 = dst->src[0];
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *) dst->op_params;
const int32_t *opts = (const int32_t *)dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
const int64_t ne0 = src0->ne[0];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t ne03 = dst->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const int64_t nb0 = src0->nb[0];
const int64_t nb1 = src0->nb[1];
const int64_t nb2 = src0->nb[2];
const int64_t nb3 = src0->nb[3];
GGML_ASSERT(ne0 == ne00 + p0 + p1);
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
const int64_t tiles0 = (ne0 + bx - 1) / bx;
const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
(unsigned)ne03);
const dpct::dim3 block_dims((unsigned)bx, 1, 1);
stream->parallel_for(
sycl::nd_range<3>(global,
local),
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
(const float *) src0->data, (float *) dst->data,
ne0, ne02, p0, p1,
nb0, nb1, nb2, nb3,
nb00, nb01, nb02, nb03
, item_ct1);
});
stream->submit([&](sycl::handler &cgh) {
auto src0_data_ct0 = src0->data;
auto dst_data_ct1 = dst->data;
auto src0_nb_ct7 = src0->nb[0];
auto src0_nb_ct8 = src0->nb[1];
auto src0_nb_ct9 = src0->nb[2];
auto src0_nb_ct10 = src0->nb[3];
auto dst_nb_ct11 = dst->nb[0];
auto dst_nb_ct12 = dst->nb[1];
auto dst_nb_ct13 = dst->nb[2];
auto dst_nb_ct14 = dst->nb[3];
cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
pad_reflect_1d_kernel_f32(
src0_data_ct0, dst_data_ct1, ne0, ne00,
ne01_packed, ne02, ne03, src0_nb_ct7,
src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
dst_nb_ct14, p0, p1, item_ct1);
});
});
}
+2
View File
@@ -3,6 +3,8 @@
#include "common.hpp"
#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
#endif // GGML_SYCL_PAD_REFLECT_1D_HPP
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,69 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
shared FLOAT_TYPE last_sum;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
if (tid == 0) {
last_sum = 0;
}
uint col = tid;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
col += BLOCK_SIZE;
}
}
@@ -4,13 +4,6 @@
#include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -156,7 +156,7 @@ void main() {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
@@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
#if !RMS_NORM_ROPE_FUSION
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif
@@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
uint get_idx() {
@@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@@ -13,8 +13,6 @@
#include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
layout (push_constant) uniform parameter
{
uint ncols;
@@ -5,13 +5,15 @@
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
@@ -10,60 +10,56 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8
#include "mul_mmq_funcs.glsl"
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#else
#error unimplemented
#endif
uint a_offset, b_offset, d_offset;
int32_t cache_b_qs[2];
int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds;
#include "mul_mat_vecq_funcs.glsl"
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_qs_idx = tid % (32 / K_PER_ITER);
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
#if QUANT_R == 2
// Assumes K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
#if K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#elif K_PER_ITER == 16
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#else
#error unimplemented
#endif
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
ibi += p.ncols;
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
}
}
}
@@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
a_offset /= QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
@@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
@@ -0,0 +1,379 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.glsl"
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
// 2-byte loads for Q4_0 blocks (18 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
// 2-byte loads for Q5_0 blocks (22 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * da * dsb.x);
}
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
}
#endif
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(ib_a, iqs);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(ib_a, iqs * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(ib_a, iqs * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
// 2 quants per call => divide sums by 8/2 = 4
return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
}
#endif
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t sum_d = 0;
int32_t sum_m = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const uint8_t scale = get_scale(ib_a, iqs * 4);
const vec2 dm = vec2(get_dm(ib_a));
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);
sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
}
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// bitwise OR to add 4 if hmask is set, subtract later
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 4;
const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
return float(data_a[ib_k].d) * float(scale - 32);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;
return i32vec4(vals0, vals1, vals2, vals3);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs;
const uint qh_shift = iqs_k / 8;
return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
#endif
}
vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
#endif
@@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#define MMQ_SHMEM
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID
@@ -9,31 +9,6 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
#ifdef DATA_A_Q4_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
#else // DATA_A_Q4_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
}
#endif // MMQ_SHMEM
#endif
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
#ifdef DATA_A_Q5_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
#else // DATA_A_Q5_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
@@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, qs_b);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
@@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
@@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
@@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
@@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif
@@ -0,0 +1,72 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout (constant_id = 1) const uint N = 64;
layout (constant_id = 2) const uint K = 32;
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
uint a_base, b_base, x_base;
FLOAT_TYPE get_a(uint r, uint c) {
return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);
}
FLOAT_TYPE get_b(uint r, uint c) {
return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);
}
void store_x(uint r, uint c, FLOAT_TYPE v) {
data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
}
shared FLOAT_TYPE shA[N * N];
shared FLOAT_TYPE shB[N * K];
void main() {
const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
if (batch >= p.ne02 * p.ne03) {
return;
}
const uint i3 = batch / p.ne22;
const uint i2 = batch % p.ne22;
a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;
b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
// Load the A matrix into shA
[[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
shA[idx] = get_a(idx / N, idx % N);
}
}
// Load the B matrix into shB
[[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
shB[idx] = get_b(idx / K, idx % K);
}
}
barrier();
FLOAT_TYPE X[N];
// Each thread solves one column
if (tid < K) {
[[unroll]] for (int r = 0; r < N; ++r) {
FLOAT_TYPE b = shB[r * K + tid];
// Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
[[unroll]] for (int c = 0; c < r; ++c) {
b -= shA[r * N + c] * X[c];
}
FLOAT_TYPE x = b / shA[r * N + r];
X[r] = x;
store_x(r, tid, x);
}
}
}
@@ -1,6 +1,7 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
@@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
@@ -0,0 +1,25 @@
// vk_op_sum_rows_push_constants
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
@@ -0,0 +1,113 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;
// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];
void topk(bool needs_bounds_check, const uint row) {
const int col = int(gl_LocalInvocationID.x);
// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[col] = ivec2(p.orig_ncols, 0);
}
barrier();
if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (col < s) {
ivec2 a = dst_row[col];
ivec2 b = dst_row[col + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[col] = b;
}
}
barrier();
}
} else {
// bitonic sort on this group of elements
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;
ivec2 sh_idx_0 = dst_row[idx_0];
ivec2 sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
dst_row[idx_0] = sh_idx_1;
dst_row[idx_1] = sh_idx_0;
}
barrier();
}
}
}
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + col] = dst_row[col].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + col] = dst_row[col];
}
}
}
void main() {
// Fast path for fully occupied workgroups
if ((p.ncols_input % BLOCK_SIZE) == 0) {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}
@@ -0,0 +1,199 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_debug_printf : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Input can either be the source (A) or intermediate values (S).
// Similarly, output can be either destination (D) or intermediate values (S).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
layout (binding = 1) writeonly buffer D {int data_d[];};
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
layout (push_constant) uniform parameter {
uint orig_ncols;
uint ncols_input;
uint ncols_output;
uint nrows;
uint first_pass;
uint last_pass;
} p;
// pairs of (gid, value)
shared ivec2 dst_row[BLOCK_SIZE];
shared int counts[SUBGROUP_SIZE];
shared int sh_min_idx;
shared uint sh_total;
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
// Map float values to uint such that comparisons still work.
// Positive values set the high bit, negative values are inverted.
// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
uint f2ui(float x) {
uint y = floatBitsToUint(x);
if ((y & 0x80000000) != 0) {
y ^= ~0;
} else {
y |= 0x80000000;
}
return y;
}
void topk(const uint row) {
const int tid = int(gl_LocalInvocationID.x);
// initialize indices
if (gl_GlobalInvocationID.x < p.ncols_input) {
if (p.first_pass != 0) {
const uint row_offset = row * p.ncols_input;
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
} else {
const uint row_offset = row * p.orig_ncols;
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
}
} else {
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
}
barrier();
if (p.ncols_output == 1) {
// Fast path for single output - just do a max reduction
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
if (tid < s) {
ivec2 a = dst_row[tid];
ivec2 b = dst_row[tid + s];
if (a.x >= p.orig_ncols ||
b.x < p.orig_ncols && b.y > a.y) {
dst_row[tid] = b;
}
}
barrier();
}
} else {
// Do an N-ary search to find the K-th largest value.
// We remap the float values to be comparable as unsigned integers,
// and split the range into 2^N smaller ranges where N is the
// subgroup size. Count how many values are in each range, if the K-th
// largest value is in the middle of one of thee ranges then repeat
// and split again.
// Mask is the current set of bits we're searching. Shift is the LSB index.
int shift = 32 - SUBGROUP_SIZE_LOG2;
uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
// The current range.
uint range_min = 0;
uint range_max = 0xFF800000;
// How many are above the current range, and how many we need to find.
uint total = 0;
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
while (mask != 0) {
barrier();
// Initialize bucket counts to zero.
if (tid < SUBGROUP_SIZE) {
counts[tid] = 0;
}
barrier();
// Count how many values are in each bucket.
if (tid < p.ncols_input) {
float y = intBitsToFloat(dst_row[tid].y);
uint fy = f2ui(y);
if (fy >= range_min && fy < range_max) {
uint bucket = (fy & mask) >> shift;
atomicAdd(counts[bucket], 1);
}
}
barrier();
// On the first subgroup, do a scan to count (from the top down) how
// many elements are in the top N buckets. Find the index of the first
// that is over the limit. Copy it to the other invocations through
// shared memory.
if (tid < SUBGROUP_SIZE) {
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
if (tid == t) {
sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
sh_total = partial_sum;
}
}
barrier();
int min_idx = sh_min_idx;
total = sh_total;
// Update the range, and break if we've found the K-th largest.
range_max = range_min + ((min_idx + 1) << shift);
range_min = range_min + (min_idx << shift);
if (total == p.ncols_output) {
break;
}
total -= counts[min_idx];
mask >>= SUBGROUP_SIZE_LOG2;
shift -= SUBGROUP_SIZE_LOG2;
if (shift < 0) {
shift = 0;
}
}
ivec2 v = dst_row[tid];
// We need to compact these values to the start of the dst_row array.
// Have each subgroup count how many items it'll store, so other
// subgroups can compute their base offset.
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
uvec4 b = subgroupBallot(top);
uint bit_count = subgroupBallotBitCount(b);
if ((tid % SUBGROUP_SIZE) == 0) {
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
}
barrier();
uint out_idx = 0;
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
if (i < tid / SUBGROUP_SIZE) {
out_idx += offset_partials[i];
}
}
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
if (top) {
// TODO: Copy directly to the output?
dst_row[out_idx + bit_count_ex] = v;
}
barrier();
}
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
if (p.last_pass != 0) {
const uint row_offset = row * p.ncols_output;
data_d[row_offset + tid] = dst_row[tid].x;
} else {
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
data_t[row_offset + tid] = dst_row[tid];
}
}
}
void main() {
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
topk(row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
@@ -0,0 +1,43 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#define GGML_TRI_TYPE_UPPER_DIAG 0
#define GGML_TRI_TYPE_UPPER 1
#define GGML_TRI_TYPE_LOWER_DIAG 2
#define GGML_TRI_TYPE_LOWER 3
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
int param = floatBitsToInt(p.param1);
bool pass = false;
switch (param) {
case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
}
if (pass) {
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
} else {
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
}
}
@@ -679,14 +679,20 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname)) {
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif
@@ -846,6 +852,9 @@ void process_shaders() {
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -913,9 +922,13 @@ void process_shaders() {
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {
@@ -940,6 +953,8 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
for (auto transpose : {false, true}) {
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
@@ -1091,7 +1106,7 @@ void write_output_files() {
for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname)) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
continue;
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
@@ -1100,6 +1115,16 @@ void write_output_files() {
src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
if (btype == "f16") {
continue;
}
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
if (basename(input_filepath) == "mul_mat_vec.comp") {
src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
}
}
+48 -29
View File
@@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
"TOP_K",
"LEAKY_RELU",
"TRI",
"FILL",
@@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"top_k(x)",
"leaky_relu(x)",
"tri(x)",
"fill(x, c)",
@@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5036,28 +5038,6 @@ struct ggml_tensor * ggml_roll(
return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result;
}
// ggml_timestep_embedding
struct ggml_tensor * ggml_timestep_embedding(
@@ -5139,6 +5119,7 @@ struct ggml_tensor * ggml_argsort(
struct ggml_tensor * a,
enum ggml_sort_order order) {
GGML_ASSERT(a->ne[0] <= INT32_MAX);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -5149,6 +5130,24 @@ struct ggml_tensor * ggml_argsort(
return result;
}
// ggml_argsort_top_k
struct ggml_tensor * ggml_argsort_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k) {
GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
result = ggml_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
return result;
}
// ggml_top_k
struct ggml_tensor * ggml_top_k(
@@ -5157,12 +5156,32 @@ struct ggml_tensor * ggml_top_k(
int k) {
GGML_ASSERT(a->ne[0] >= k);
struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
result = ggml_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
result->op = GGML_OP_TOP_K;
result->src[0] = a;
return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step) {
GGML_ASSERT(stop > start);
const int64_t steps = (int64_t) ceilf((stop - start) / step);
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
ggml_set_op_params_f32(result, 0, start);
ggml_set_op_params_f32(result, 1, stop);
ggml_set_op_params_f32(result, 2, step);
result->op = GGML_OP_ARANGE;
return result;
}
+66
View File
@@ -25,6 +25,20 @@ class Keys:
ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"
# Recommended Sampler Parameters
SAMPLING_SEQUENCE = "general.sampling.sequence"
SAMPLING_TOP_K = "general.sampling.top_k"
SAMPLING_TOP_P = "general.sampling.top_p"
SAMPLING_MIN_P = "general.sampling.min_p"
SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability"
SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold"
SAMPLING_TEMP = "general.sampling.temp"
SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n"
SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat"
SAMPLING_MIROSTAT = "general.sampling.mirostat"
SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau"
SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta"
# Authorship Metadata
NAME = "general.name"
AUTHOR = "general.author"
@@ -352,6 +366,7 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
PHI2 = auto()
@@ -427,6 +442,7 @@ class MODEL_ARCH(IntEnum):
APERTUS = auto()
COGVLM = auto()
MINIMAXM2 = auto()
RND1 = auto()
PANGU_EMBED = auto()
@@ -516,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
SSM_BETA_ALPHA = auto() # qwen3next
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
@@ -721,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.PHI2: "phi2",
@@ -797,6 +815,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.APERTUS: "apertus",
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
}
@@ -884,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -1553,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN3NEXT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -2991,6 +3040,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.VISEXP_UP,
MODEL_TENSOR.VISEXP_DOWN,
],
MODEL_ARCH.RND1: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.PANGU_EMBED: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
+53 -6
View File
@@ -4,6 +4,7 @@ import logging
import os
import shutil
import struct
import sys
import tempfile
from dataclasses import dataclass
from enum import Enum, auto
@@ -370,10 +371,15 @@ class GGUFWriter:
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
) -> None:
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
# if tensor endianness is not passed, assume it's native to system
if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
fp.seek(0)
@@ -394,13 +400,18 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
# if tensor endianness is not passed, assume it's native to system
if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
file_id = -1
for i, tensors in enumerate(self.tensors):
@@ -496,6 +507,42 @@ class GGUFWriter:
def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)
def add_sampling_sequence(self, sequence: str) -> None:
self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
def add_sampling_top_k(self, top_k: int) -> None:
self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)
def add_sampling_top_p(self, top_p: float) -> None:
self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)
def add_sampling_min_p(self, min_p: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)
def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
def add_sampling_temp(self, temp: float) -> None:
self.add_float32(Keys.General.SAMPLING_TEMP, temp)
def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
def add_sampling_mirostat(self, mirostat: int) -> None:
self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)
def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
+85
View File
@@ -17,6 +17,20 @@ logger = logging.getLogger("metadata")
@dataclass
class Metadata:
# Recommended Sampler Parameters to be written to GGUF KV Store
sampling_sequence: Optional[str] = None
sampling_top_k: Optional[int] = None
sampling_top_p: Optional[float] = None
sampling_min_p: Optional[float] = None
sampling_xtc_probability: Optional[float] = None
sampling_xtc_threshold: Optional[float] = None
sampling_temp: Optional[float] = None
sampling_penalty_last_n: Optional[int] = None
sampling_penalty_repeat: Optional[float] = None
sampling_mirostat: Optional[int] = None
sampling_mirostat_tau: Optional[float] = None
sampling_mirostat_eta: Optional[float] = None
# Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None
author: Optional[str] = None
@@ -54,15 +68,43 @@ class Metadata:
model_card = Metadata.load_model_card(model_path)
hf_params = Metadata.load_hf_parameters(model_path)
gen_config = Metadata.load_generation_config(model_path)
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
# heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
if gen_config:
metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
# Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
@@ -172,6 +214,23 @@ class Metadata:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
generation_config_path = model_path / "generation_config.json"
if not generation_config_path.is_file():
return {}
try:
with open(generation_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
# not all models have valid generation_config.json
return {}
@staticmethod
def id_to_title(string):
# Convert capitalization into title form unless acronym or version number
@@ -546,6 +605,32 @@ class Metadata:
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None
if self.sampling_sequence is not None:
gguf_writer.add_sampling_sequence(self.sampling_sequence)
if self.sampling_top_k is not None:
gguf_writer.add_sampling_top_k(self.sampling_top_k)
if self.sampling_top_p is not None:
gguf_writer.add_sampling_top_p(self.sampling_top_p)
if self.sampling_min_p is not None:
gguf_writer.add_sampling_min_p(self.sampling_min_p)
if self.sampling_xtc_probability is not None:
gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
if self.sampling_xtc_threshold is not None:
gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
if self.sampling_temp is not None:
gguf_writer.add_sampling_temp(self.sampling_temp)
if self.sampling_penalty_last_n is not None:
gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
if self.sampling_penalty_repeat is not None:
gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
if self.sampling_mirostat is not None:
gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
if self.sampling_mirostat_tau is not None:
gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
if self.sampling_mirostat_eta is not None:
gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
gguf_writer.add_name(self.name)
if self.author is not None:

Some files were not shown because too many files have changed in this diff Show More