Compare commits

..

35 Commits

Author SHA1 Message Date
Akarshan Biswas 6c02a032fa SYCL: Remove misleading ggml_sycl_op_flatten function (#12387)
* SYCL: Remove misleading ggml_sycl_op_flatten function

* remove trailing whitespace

* Fix L2 norm from rebase

* remove try catch block from element_wise.cpp

* remove comment from common.hp

* ggml-sycl.cpp: Add try catch sycl::exception block in compute_forward

* norm.cpp: remove try catch exception block
2025-03-31 11:25:24 +02:00
Sigbjørn Skjæret f52d59d771 llava : fix clip loading GGUFs with missing description (#12660) 2025-03-31 11:07:07 +02:00
marcoStocchi 52de2e5949 tts : remove printfs (#12640)
* tts.cpp : llama tokens console output is done using LOG_INF instead of printf(). Therefore the options '--log-disable' and '--log-file' have now uniform impact on all output.
2025-03-31 11:20:30 +03:00
Sigbjørn Skjæret 2c3f8b850a llama : support BailingMoE (Ling) (#12634) 2025-03-30 22:21:03 +02:00
Georgi Gerganov 4663bd353c metal : use constexpr in FA kernels + fix typedef (#12659)
* metal : use constexpr in FA kernels

ggml-ci

* cont

ggml-ci

* cont : fix typedef

ggml-ci
2025-03-30 22:04:04 +03:00
Juyoung Suk b3de7cac73 llama : add Trillion 7B model support (#12556)
* Support Trillion 7B

* Update llama.h

* Update llama.h

* Update llama-vocab.cpp for Trillion

* Update llama-vocab.cpp
2025-03-30 20:38:33 +02:00
Sergei Vorobyov 7242dd9675 llama-chat : Add Yandex instruct model template support (#12621)
* add yandex template

* update yandex chat template

* fix tests

* adjust chat template

* fix style

* fix tool macro in template

* add clarify comment

---------

Co-authored-by: Sergei Vorobev <serv01@yandex-team.ru>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-03-30 20:12:03 +02:00
R0CKSTAR 492d7f1ff7 musa: fix all warnings, re-enable -DLLAMA_FATAL_WARNINGS=ON in ci and update doc (#12611)
* musa: fix all warnings

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: enable -DLLAMA_FATAL_WARNINGS=ON in run.sh

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* musa: update ci doc (install ccache)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* fix Windows build issue

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Address review comments

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Address review comments

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2025-03-30 10:59:38 +02:00
Georgi Gerganov d3f1f0acfb sync : ggml
ggml-ci
2025-03-30 08:33:31 +03:00
Xuan-Son Nguyen 360dc22c00 cpu : rm unused variable (ggml/1166) 2025-03-30 08:33:31 +03:00
cmdr2 a62d7fa7a9 cpu: de-duplicate some of the operators and refactor (ggml/1144)
* cpu: de-duplicate some of the operators and refactor

* Fix PR comments

* Fix PR comments
2025-03-30 08:33:31 +03:00
Daniel Bevenius e408d4351a ggml : add logging for native build options/vars (whisper/2935)
This commit adds debug level logging for the native build options and
variables to ggml/CMakeLists.txt.

The motivation for this is that it can be useful to see the effective
result of `GGML_NATIVE`, `GGML_NATIVE_DEFAULT`, and `INS_ENB` for a
cmake build. I've found myself adding similar logging a few times now,
so I thought it might be a good idea to add this.

Example output, specifying `-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG` when
running cmake produces the following output:
```console
-- GGML_NATIVE         : OFF
-- GGML_NATIVE_DEFAULT : OFF
-- INS_ENB             : OFF
```
2025-03-30 08:33:31 +03:00
Daniel Bevenius 3891e183c6 examples : command.wasm updates (whisper/2904)
This commit updates the command.wasm example by adding a server.py script to make it easy to start a local http server to try out the example, updates the build instructions, and also addresses some of the compiler warnings that were being generated.

* emscripten : fix TOTAL_STACK for wasm

This commit moves the TOTAL_STACK setting from the compile flags to the
linker flags. This is because the TOTAL_STACK setting is a linker
setting.

The motivation for this change is that currently the following warnings
are generated when building:
```console
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
```

* examples : suppress C++17 deprecation warning for std::codecvt_utf8

This commit suppresses the C++17 deprecation warning for
std::codecvt_utf8 similar to what is done in
examples/talk-llama/unicode.cpp.

The motivation for this change is to suppress these warnings:
```console
/Users/danbev/work/ai/whisper-work/examples/common.cpp:251:31: warning: 'codecvt_utf8<wchar_t>' is deprecated [-Wdeprecated-declarations]
  251 |     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
      |                               ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/codecvt:193:28: note: 'codecvt_utf8<wchar_t>' has been explicitly marked deprecated here
  193 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 codecvt_utf8 : public __codecvt_utf8<_Elem> {
      |                            ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
  723 | #    define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
      |                                         ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
  688 | #      define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
      |                                                 ^
/Users/danbev/work/ai/whisper-work/examples/common.cpp:251:10: warning: 'wstring_convert<std::codecvt_utf8<wchar_t>>' is deprecated [-Wdeprecated-declarations]
  251 |     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
      |          ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/locale:3145:28: note: 'wstring_convert<std::codecvt_utf8<wchar_t>>' has been explicitly marked deprecated here
 3145 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 wstring_convert {
      |                            ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
  723 | #    define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
      |                                         ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
  688 | #      define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
      |                                                 ^
/Users/danbev/work/ai/whisper-work/examples/common.cpp:257:31: warning: 'codecvt_utf8<wchar_t>' is deprecated [-Wdeprecated-declarations]
  257 |     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
      |                               ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/codecvt:193:28: note: 'codecvt_utf8<wchar_t>' has been explicitly marked deprecated here
  193 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 codecvt_utf8 : public __codecvt_utf8<_Elem> {
      |                            ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
  723 | #    define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
      |                                         ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
  688 | #      define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
      |                                                 ^
/Users/danbev/work/ai/whisper-work/examples/common.cpp:257:10: warning: 'wstring_convert<std::codecvt_utf8<wchar_t>>' is deprecated [-Wdeprecated-declarations]
  257 |     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
      |          ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/locale:3145:28: note: 'wstring_convert<std::codecvt_utf8<wchar_t>>' has been explicitly marked deprecated here
 3145 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 wstring_convert {
      |                            ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
  723 | #    define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
      |                                         ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
  688 | #      define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
      |                                                 ^
4 warnings generated.
```

* ggml : suppress double-promotion warning in GGML_F16x4_REDUCE

This commit adds a cast to `ggml_float` in the `GGML_F16x4_REDUCE` macro
to suppress a double-promotion warning.

Currently the following warning is generated when compiling the
command.wasm example:
```console
/whisper-work/src/ggml-cpu/ggml-cpu.c:1592:5: warning: implicit conversion increases floating-point precision: 'float' to 'ggml_float' (aka 'double') [-Wdouble-promotion]
 1592 |     GGML_F16_VEC_REDUCE(sumf, sum);
      |     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/danbev/work/ai/whisper-work/src/ggml-cpu/ggml-cpu.c:932:37: note: expanded from macro 'GGML_F16_VEC_REDUCE'
  932 | #define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
      |                                     ^
/Users/danbev/work/ai/whisper-work/src/ggml-cpu/ggml-cpu.c:920:44: note: expanded from macro 'GGML_F16x4_REDUCE'
  918 |     res = wasm_f32x4_extract_lane(x[0], 0) +       \
      |         ~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  919 |           wasm_f32x4_extract_lane(x[0], 1) +       \
      |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  920 |           wasm_f32x4_extract_lane(x[0], 2) +       \
      |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~
  921 |           wasm_f32x4_extract_lane(x[0], 3);        \
      |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/whisper-work/src/ggml-cpu/ggml-cpu.c:1640:9: warning: implicit conversion increases floating-point precision: 'float' to 'ggml_float' (aka 'double') [-Wdouble-promotion]
 1640 |         GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
      |         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/danbev/work/ai/whisper-work/src/ggml-cpu/ggml-cpu.c:932:37: note: expanded from macro 'GGML_F16_VEC_REDUCE'
  932 | #define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
      |                                     ^
/Users/danbev/work/ai/whisper-work/src/ggml-cpu/ggml-cpu.c:920:44: note: expanded from macro 'GGML_F16x4_REDUCE'
  918 |     res = wasm_f32x4_extract_lane(x[0], 0) +       \
      |         ~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  919 |           wasm_f32x4_extract_lane(x[0], 1) +       \
      |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  920 |           wasm_f32x4_extract_lane(x[0], 2) +       \
      |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~
  921 |           wasm_f32x4_extract_lane(x[0], 3);        \
      |           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 warnings generated.
```
wasm_f32x4_extract_lane returns a 32-bit float and this is what the
addition is performed on. But there is an implicit conversion from
32-bit float to 64-bit double when the result is assigned to `res`,
which is of type `ggml_float`. My understanding here is that this is
intentional and adding a cast to `ggml_float` should suppress the
warning.

* emscripten : add -Wno-deprecated to for emscripten

This commit adds -Wno-deprecated to the CMAKE_CXX_FLAGS for emscripten
builds.

The motivation for this is that currently there a number of warnings
generated like the following:
```console
warning: JS library symbol '$print' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
warning: JS library symbol '$printErr' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
em++: warning: warnings in JS library compilation [-Wjs-compiler]
em++: warning: linker setting ignored during compilation: 'ENVIRONMENT' [-Wunused-command-line-argument]
warning: JS library symbol '$print' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
warning: JS library symbol '$printErr' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
em++: warning: warnings in JS library compilation [-Wjs-compiler]
warning: JS library symbol '$print' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
warning: JS library symbol '$printErr' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
em++: warning: warnings in JS library compilation [-Wjs-compiler]
em++: warning: linker setting ignored during compilation: 'ENVIRONMENT' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'ENVIRONMENT' [-Wunused-command-line-argument]
```

The downside of this is that we might miss other deprecation warnings
in the future so I'm not sure if this is acceptable. But it make the
wasm examples cleaner without the warnings.

* examples : fix tautological-compare warning in stb_vorbis.c [no ci]

This commit applies a fix to address a tautological-compare warning
in stb_vorbis.c.

The motivation for this is that currently the following warning is
generated when compiling the commmand-wasm example:
```console
/Users/danbev/work/ai/whisper-work/examples/stb_vorbis.c:1404:75: warning: pointer comparison always evaluates to false [-Wtautological-compare]
 1404 |       if (f->stream_start + loc >= f->stream_end || f->stream_start + loc < f->stream_start) {
      |                                                                           ^
1 warning generated.
```

This fix was taken from an open pull request on the stb repository
that addreses this issue:
https://github.com/nothings/stb/pull/1746

* squash! examples : update command.wasm instructions [no ci]

This commit adds a Python script to serve the the wasm examples build
in the `build-em` directory. Initially I thought that it would be enough
to start a simple python server but I did not notice that there was an
error in the browser console when I did that:
```console
command.js:1 Uncaught (in promise) DataCloneError: Failed to execute 'postMessage' on 'Worker': SharedArrayBuffer transfer requires self.crossOriginIsolated.
    at command.js:1:1206224
    at new Promise (<anonymous>)
    at loadWasmModuleToWorker (command.js:1:1204981)
    at Array.map (<anonymous>)
    at Object.loadWasmModuleToAllWorkers (command.js:1:1206428)
    at command.js:1:1204318
    at callRuntimeCallbacks (command.js:1:1202062)
    at preRun (command.js:1:6136)
    at run (command.js:1:1294094)
    at removeRunDependency (command.js:1:7046)
```
We need a few CORS headers to be set and in order hopefully make this
easy for users a Python script is added to the examples directory.
This should be able to server all the wasm examples provided they have
been built. command.wasm's README.md is updated to reflect this change.

* examples : remove unused functions

This commit removed the unused functions convert_to_utf8 and
convert_to_wstring from examples/common.cpp.

* Revert "examples : fix tautological-compare warning in stb_vorbis.c [no ci]"

This reverts commit 8e3c47d96141c7675c985562ebdc705e839e338a.

We should not make this change here and instead when the upstream PR is
merged we can sync with it.

Refs: https://github.com/ggerganov/whisper.cpp/issues/2784
2025-03-30 08:33:31 +03:00
Xuan-Son Nguyen af6ae1efb2 llama : fix non-causal mask for gemma 3 (#12615) 2025-03-30 00:07:37 +01:00
Djip007 0bb2919335 llama : change cpu_buft_list order: ACCEL -> GPU host -> CPU extra -> CPU (#12632)
this allow to use GPU host when possible over CPU repack.
this have the same effect to resolve this issues (#12498) without
completely disable CPU extra buffer.

Co-authored-by: philou <philou@framework>
2025-03-29 14:07:37 +01:00
Jay a69f846351 cmake : fix ccache conflict (#12522)
If users already set CMAKE_C_COMPILER_LAUNCHER globally, setting it in
cmake again will lead to conflict and compile fail.

Signed-off-by: Jay <BusyJay@users.noreply.github.com>
2025-03-29 11:04:58 +01:00
hipudding d07a0d7a79 CANN : remove clang-format in ggml-cann (#12607) 2025-03-29 11:03:28 +01:00
Sigbjørn Skjæret 3714c3ee1a llama : fix incorrect Qwen2Moe ffn_moe_out graph callback (#12631) 2025-03-28 22:13:02 +01:00
Georgi Gerganov b4ae50810e metal : improve FA + improve MoE (#12612)
* ggml : FA with different K, V head sizes (CPU)

ggml-ci

* metal : add FA with HS=192

* metal : extend FA to support different K and V head sizes

ggml-ci

* metal : add FA vector kernels for heads K 192 and V 128

ggml-ci

* ggml : restrict op on other backends to equal head sizes

ggml-ci

* metal : optimize FA-vec kernel

ggml-ci

* metal : FA remove mq registers

* metal : improve MoE mul_mat_id condition

ggml-ci

* metal : fix comments + remove unnecessary addition

ggml-ci

* metal : avoid too much shared memory usage with mul_mat_id

ggml-ci
2025-03-28 20:21:59 +02:00
Icenowy Zheng b86f600723 vulkan: fix coopmat shader generation when cross-compiling (#12272)
* vulkan: fix coopmat shader generation when cross-compiling

Previously the status of coopmat{,2} support isn't passed to the
vulkan-shaders-gen project building on the host, which leads to build
failure because of the cross-compiling code expecting coopmat{,2}
shaders that didn't get generated.

Fix this by passing the coopmat{,2} support status to vulkan-shaders
subproject.

Signed-off-by: Icenowy Zheng <uwu@icenowy.me>

* Only call coop-mat shaders once

* Fix whitespace

---------

Signed-off-by: Icenowy Zheng <uwu@icenowy.me>
Co-authored-by: bandoti <141645996+bandoti@users.noreply.github.com>
2025-03-28 14:51:06 -03:00
Johannes Gäßler dd373dd3bf llama: fix error on bad grammar (#12628) 2025-03-28 18:08:52 +01:00
Benson Wong 5d01670266 server : include speculative decoding stats when timings_per_token is enabled (#12603)
* Include speculative decoding stats when timings_per_token is true

New fields added to the `timings` object:

  - draft_n           : number of draft tokens generated
  - draft_accepted_n  : number of draft tokens accepted
  - draft_accept_ratio: ratio of accepted/generated

* Remove redundant draft_accept_ratio var

* add draft acceptance rate to server console output
2025-03-28 10:05:44 +02:00
Radoslav Gerganov ef03229ff4 rpc : update README for cache usage (#12620) 2025-03-28 09:44:13 +02:00
amritahs-ibm 13731766db llamafile : ppc64le GEMV forwarding for FP32. (#12594)
This patch enables usage of MMA when one of the
dimensions of the matrix(ie either M or N) is 1. This
is useful in case of token generation where N < 2.

The concept of 'GEMV Forwarding' is used where when one
of the matrix has a single row/column, the elements are
broadcasted, instead of using packing routine to prepack
the matrix elements.

This change results in 5% - 15% improvement in total
speed(ie all tokens/total time), across various batch
sizes. This is in comparision with the corresponding
dot product implementation.

The patch is tested with FP32 models of Meta-Lllama-3-8B,
Mistral-7B, Llama-2-7B-chat-hf on a IBM POWER10 machine.

Signed-off-by: Amrita H S <amritahs@linux.vnet.ibm.com>
2025-03-28 09:43:22 +02:00
Radoslav Gerganov ab6ab8f809 rpc : send hash when tensor data is above some fixed threshold (#12496)
* rpc : send hash when tensor data is above some fixed threshold

ref #10095

* rpc : put cache under $HOME/.cache/llama.cpp

* try to fix win32 build

* another try to fix win32 build

* remove llama as dependency
2025-03-28 08:18:04 +02:00
Piotr 2099a9d5db server : Support listening on a unix socket (#12613)
* server : Bump cpp-httplib to include AF_UNIX windows support

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>

* server : Allow running the server example on a unix socket

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>

---------

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
2025-03-27 23:41:04 +01:00
Georgi Gerganov 2969019837 media : add SVG logo [no ci] (#12616) 2025-03-27 23:09:05 +02:00
lhez 5dec47dcd4 opencl: add multi and vision rope, gelu_quick and im2col (#12600)
* opencl: add `im2col`

* opencl: add `gelu_quick`

* opencl: add mrope

* opencl: add vision rope
2025-03-27 08:08:08 -07:00
Si1w f125b8dccf llama : add PLM GGUF Conversion & Inference Support (#12457)
* add edgellm model arch[conversation feature doesn't work]

* remove output.weight layer for edgellm arch

* [Model] update the name of the model

* update the name of model arch in convert gguf

* [Model] Refarctor the model arch into llama-model

* [Bug] Fix the bug in create attn kv

* [Code] Fix editorconfig erros

* [Code] Remove Trailing whitespace

* [Code] Remove Trailing whitespace

* [Code] Change the order of model arch in list

* [Code] Fix flake8 Lint errors

* Remove trailing white space

* [Code] Remove  call in model arch
2025-03-27 12:49:15 +02:00
HighDoping 953c2a62cf model : restore support for T5Encoder (#12590) 2025-03-27 11:43:33 +01:00
Csaba Kecskemeti d5c6309d91 convert : Support Qwen2_5_VLForConditionalGeneration (#12595) 2025-03-27 11:11:23 +01:00
Georgi Gerganov 029c693fdc sync : ggml
ggml-ci
2025-03-27 10:09:29 +02:00
Georgi Gerganov 771d84371c scripts : update sync + fix cmake merge
ggml-ci
2025-03-27 10:09:29 +02:00
Georgi Gerganov df0665a483 sync : ggml
ggml-ci
2025-03-27 09:04:38 +02:00
Georgi Gerganov 0306aad1ca cmake : sync/merge PowerPC build commands (#0) 2025-03-27 09:04:38 +02:00
91 changed files with 4403 additions and 4003 deletions
+2
View File
@@ -112,6 +112,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
#### Multimodal
+1 -1
View File
@@ -60,7 +60,7 @@ docker run --privileged -it \
Inside the container, execute the following commands:
```bash
apt update -y && apt install -y bc cmake git python3.10-venv time unzip wget
apt update -y && apt install -y bc cmake ccache git python3.10-venv time unzip wget
git config --global --add safe.directory /ws
GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache
```
+1 -1
View File
@@ -69,7 +69,7 @@ fi
if [ ! -z ${GG_BUILD_MUSA} ]; then
# Use qy1 by default (MTT S80)
MUSA_ARCH=${MUSA_ARCH:-21}
CMAKE_EXTRA="-DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
fi
## helpers
+1 -1
View File
@@ -1979,7 +1979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--host"}, "HOST",
string_format("ip address to listen (default: %s)", params.hostname.c_str()),
string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
[](common_params & params, const std::string & value) {
params.hostname = value;
}
+3
View File
@@ -208,6 +208,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
if (!grmr) {
return nullptr;
}
}
auto * result = new common_sampler {
+132 -1
View File
@@ -708,6 +708,12 @@ class Model:
if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f":
# ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k
res = "superbpe"
if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15":
# ref: https://huggingface.co/trillionlabs/Trillion-7B-preview
res = "trillion"
if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
# ref: https://huggingface.co/inclusionAI/Ling-lite
res = "bailingmoe"
if res is None:
logger.warning("\n")
@@ -2269,7 +2275,7 @@ class Qwen2Model(Model):
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
@Model.register("Qwen2VLForConditionalGeneration")
@Model.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN2VL
@@ -4419,6 +4425,29 @@ class DeepseekV2Model(Model):
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("PLMForCausalLM")
class PLMModel(Model):
model_arch = gguf.MODEL_ARCH.PLM
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["v_head_dim"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
@Model.register("T5WithLMHeadModel")
@Model.register("T5ForConditionalGeneration")
@Model.register("MT5ForConditionalGeneration")
@@ -5107,6 +5136,108 @@ class GraniteMoeModel(GraniteModel):
return super().modify_tensors(data_torch, name, bid)
@Model.register("BailingMoeForCausalLM")
class BailingMoeModel(Model):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if "head_dim" in hparams:
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_weights_scale(1.0)
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
_experts: list[dict[str, Tensor]] | None = None
@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
n_embd = self.hparams["hidden_size"]
head_dim = self.hparams.get("head_dim", n_embd // n_head)
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
if name.endswith("attention.dense.weight"):
return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)]
elif name.endswith("query_key_value.weight"):
q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2)
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v)
]
elif name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
tensors: list[tuple[str, Tensor]] = []
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
new_name = self.map_tensor_name(name)
if new_name == output_name and self.hparams.get("norm_head"):
data_torch = data_torch.float()
data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7
return [(new_name, data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("ChameleonForConditionalGeneration")
@Model.register("ChameleonForCausalLM") # obsolete
class ChameleonModel(Model):
+2
View File
@@ -111,6 +111,8 @@ models = [
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
]
+5 -3
View File
@@ -1396,14 +1396,16 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
const int n_kv = gguf_get_n_kv(ctx);
const int ftype = get_u32(ctx, KEY_FTYPE);
const std::string ftype_str = get_ftype(ftype);
const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
const std::string description = gguf_get_val_str(ctx, idx_desc);
const int idx_name = gguf_find_key(ctx, KEY_NAME);
if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug
const std::string name = gguf_get_val_str(ctx, idx_name);
LOG_INF("%s: model name: %s\n", __func__, name.c_str());
}
LOG_INF("%s: description: %s\n", __func__, description.c_str());
const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
if (idx_desc != -1) { // ditto
const std::string description = gguf_get_val_str(ctx, idx_desc);
LOG_INF("%s: description: %s\n", __func__, description.c_str());
}
LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);
+4 -2
View File
@@ -1,2 +1,4 @@
add_executable(rpc-server rpc-server.cpp)
target_link_libraries(rpc-server PRIVATE ggml llama)
set(TARGET rpc-server)
add_executable(${TARGET} rpc-server.cpp)
target_link_libraries(${TARGET} PRIVATE ggml)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
+11
View File
@@ -72,3 +72,14 @@ $ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name
This way you can offload model layers to both local and remote devices.
### Local cache
The RPC server can use a local cache to store large tensors and avoid transferring them over the network.
This can speed up model loading significantly, especially when using large models.
To enable the cache, use the `-c` option:
```bash
$ bin/rpc-server -c
```
By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable.
+140 -6
View File
@@ -1,3 +1,7 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml-cpu.h"
#ifdef GGML_USE_CUDA
@@ -18,26 +22,142 @@
#include "ggml-rpc.h"
#ifdef _WIN32
# define DIRECTORY_SEPARATOR '\\'
# include <locale>
# include <windows.h>
# include <fcntl.h>
# include <io.h>
#else
# define DIRECTORY_SEPARATOR '/'
# include <unistd.h>
# include <sys/stat.h>
#endif
#include <codecvt>
#include <string>
#include <stdio.h>
#include <vector>
#include <filesystem>
namespace fs = std::filesystem;
// NOTE: this is copied from common.cpp to avoid linking with libcommon
// returns true if successful, false otherwise
static bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wpath = converter.from_bytes(path);
// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());
if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
return true;
}
size_t pos_slash = 0;
// process path from front to back, procedurally creating directories
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
const std::wstring subpath = wpath.substr(0, pos_slash);
const wchar_t * test = subpath.c_str();
const bool success = CreateDirectoryW(test, NULL);
if (!success) {
const DWORD error = GetLastError();
// if the path already exists, ensure that it's a directory
if (error == ERROR_ALREADY_EXISTS) {
const DWORD attributes = GetFileAttributesW(subpath.c_str());
if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
return false;
}
} else {
return false;
}
}
pos_slash += 1;
}
return true;
#else
// if the path already exists, check whether it's a directory
struct stat info;
if (stat(path.c_str(), &info) == 0) {
return S_ISDIR(info.st_mode);
}
size_t pos_slash = 1; // skip leading slashes for directory creation
// process path from front to back, procedurally creating directories
while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
const std::string subpath = path.substr(0, pos_slash);
struct stat info;
// if the path already exists, ensure that it's a directory
if (stat(subpath.c_str(), &info) == 0) {
if (!S_ISDIR(info.st_mode)) {
return false;
}
} else {
// create parent directories
const int ret = mkdir(subpath.c_str(), 0755);
if (ret != 0) {
return false;
}
}
pos_slash += 1;
}
return true;
#endif // _WIN32
}
// NOTE: this is copied from common.cpp to avoid linking with libcommon
static std::string fs_get_cache_directory() {
std::string cache_directory = "";
auto ensure_trailing_slash = [](std::string p) {
// Make sure to add trailing slash
if (p.back() != DIRECTORY_SEPARATOR) {
p += DIRECTORY_SEPARATOR;
}
return p;
};
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#ifdef __linux__
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else {
cache_directory = std::getenv("HOME") + std::string("/.cache/");
}
#elif defined(__APPLE__)
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
cache_directory = std::getenv("LOCALAPPDATA");
#endif // __linux__
cache_directory = ensure_trailing_slash(cache_directory);
cache_directory += "llama.cpp";
}
return ensure_trailing_slash(cache_directory);
}
struct rpc_server_params {
std::string host = "127.0.0.1";
int port = 50052;
size_t backend_mem = 0;
bool use_cache = false;
};
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}
@@ -58,6 +178,8 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (params.port <= 0 || params.port > 65535) {
return false;
}
} else if (arg == "-c" || arg == "--cache") {
params.use_cache = true;
} else if (arg == "-m" || arg == "--mem") {
if (++i >= argc) {
return false;
@@ -164,8 +286,20 @@ int main(int argc, char * argv[]) {
} else {
get_backend_memory(&free_mem, &total_mem);
}
printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem);
const char * cache_dir = nullptr;
std::string cache_dir_str = fs_get_cache_directory() + "rpc/";
if (params.use_cache) {
if (!fs_create_directory_with_parents(cache_dir_str)) {
fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
return 1;
}
cache_dir = cache_dir_str.c_str();
}
printf("Starting RPC server\n");
printf(" endpoint : %s\n", endpoint.c_str());
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
ggml_backend_rpc_start_server(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
ggml_backend_free(backend);
return 0;
}
+314 -248
View File
File diff suppressed because it is too large Load Diff
+57 -8
View File
@@ -489,8 +489,12 @@ struct result_timings {
double predicted_per_token_ms;
double predicted_per_second;
// Optional speculative metrics - only included when > 0
int32_t draft_n = 0;
int32_t draft_n_accepted = 0;
json to_json() const {
return {
json base = {
{"prompt_n", prompt_n},
{"prompt_ms", prompt_ms},
{"prompt_per_token_ms", prompt_per_token_ms},
@@ -501,6 +505,13 @@ struct result_timings {
{"predicted_per_token_ms", predicted_per_token_ms},
{"predicted_per_second", predicted_per_second},
};
if (draft_n > 0) {
base["draft_n"] = draft_n;
base["draft_n_accepted"] = draft_n_accepted;
}
return base;
}
};
@@ -1299,6 +1310,10 @@ struct server_slot {
std::function<void(int)> callback_on_release;
// Speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
void reset() {
SLT_DBG(*this, "%s", "\n");
@@ -1315,6 +1330,10 @@ struct server_slot {
generated_tokens.clear();
generated_token_probs.clear();
// clear speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
}
bool is_non_causal() const {
@@ -1381,6 +1400,12 @@ struct server_slot {
timings.predicted_per_token_ms = t_token_generation / n_decoded;
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
// Add speculative metrics
if (n_draft_total > 0) {
timings.draft_n = n_draft_total;
timings.draft_n_accepted = n_draft_accepted;
}
return timings;
}
@@ -1428,6 +1453,15 @@ struct server_slot {
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
t_token_generation, n_decoded, t_gen, n_gen_second,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
if (n_draft_total > 0) {
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
SLT_INF(*this,
"\n"
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
draft_ratio, n_draft_accepted, n_draft_total
);
}
}
json to_json() const {
@@ -3290,6 +3324,9 @@ struct server_context {
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// keep track of total number of tokens generated in the draft
slot.n_draft_total += draft.size();
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
@@ -3315,6 +3352,9 @@ struct server_context {
slot.n_past += ids.size();
slot.n_decoded += ids.size();
// update how many tokens out of draft was accepted
slot.n_draft_accepted += ids.size() - 1;
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
@@ -4459,15 +4499,24 @@ int main(int argc, char ** argv) {
llama_backend_free();
};
// bind HTTP listen port
bool was_bound = false;
if (params.port == 0) {
int bound_port = svr->bind_to_any_port(params.hostname);
if ((was_bound = (bound_port >= 0))) {
params.port = bound_port;
}
if (string_ends_with(std::string(params.hostname), ".sock")) {
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
svr->set_address_family(AF_UNIX);
// bind_to_port requires a second arg, any value other than 0 should
// simply get ignored
was_bound = svr->bind_to_port(params.hostname, 8080);
} else {
was_bound = svr->bind_to_port(params.hostname, params.port);
LOG_INF("%s: binding port with default address family\n", __func__);
// bind HTTP listen port
if (params.port == 0) {
int bound_port = svr->bind_to_any_port(params.hostname);
if ((was_bound = (bound_port >= 0))) {
params.port = bound_port;
}
} else {
was_bound = svr->bind_to_port(params.hostname, params.port);
}
}
if (!was_bound) {
+5 -3
View File
@@ -699,11 +699,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const std::string voice_data = audio_data;
auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
std::ostringstream tokens_oss;
for (size_t i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
tokens_oss << tmp[i] << ", ";
}
printf("\n\n");
LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());
prompt_add(prompt_inp, tmp);
#else
prompt_add(prompt_inp, llama_tokens {
+6 -1
View File
@@ -100,6 +100,10 @@ else()
set(INS_ENB ON)
endif()
message(DEBUG "GGML_NATIVE : ${GGML_NATIVE}")
message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
message(DEBUG "INS_ENB : ${INS_ENB}")
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
@@ -127,7 +131,8 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
if (WIN32)
+22
View File
@@ -0,0 +1,22 @@
find_package(Git)
# the commit's SHA1
execute_process(COMMAND
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
OUTPUT_VARIABLE GIT_SHA1
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
# the date of the commit
execute_process(COMMAND
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
OUTPUT_VARIABLE GIT_DATE
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
# the subject of the commit
execute_process(COMMAND
"${GIT_EXECUTABLE}" log -1 --format=%s
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
+1 -1
View File
@@ -5,7 +5,7 @@
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
find_package(Threads REQUIRED)
+3 -1
View File
@@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
const char * cache_dir,
size_t free_mem, size_t total_mem);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
+5 -5
View File
@@ -1791,11 +1791,11 @@ extern "C" {
#define GGML_KQ_MASK_PAD 64
// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
// q: [n_embd_k, n_batch, n_head, 1]
// k: [n_embd_k, n_kv, n_head_kv, 1]
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
+1 -1
View File
@@ -65,7 +65,7 @@ if (GGML_LTO)
endif()
endif()
if (GGML_CCACHE)
if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER)
find_program(GGML_CCACHE_FOUND ccache)
find_program(GGML_SCCACHE_FOUND sccache)
-168
View File
@@ -1,168 +0,0 @@
---
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveMacros: false
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: Left
AlignOperands: true
AlignTrailingComments: true
AllowAllArgumentsOnNextLine: true
AllowAllConstructorInitializersOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortLambdasOnASingleLine: All
AllowShortIfStatementsOnASingleLine: WithoutElse
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterCaseLabel: false
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
SplitEmptyFunction: true
SplitEmptyRecord: true
SplitEmptyNamespace: true
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: BeforeColon
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DeriveLineEnding: true
DerivePointerAlignment: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros:
- foreach
- Q_FOREACH
- BOOST_FOREACH
IncludeBlocks: Regroup
IncludeCategories:
- Regex: '^<ext/.*\.h>'
Priority: 2
SortPriority: 0
- Regex: '^<.*\.h>'
Priority: 1
SortPriority: 0
- Regex: '^<.*'
Priority: 2
SortPriority: 0
- Regex: '.*'
Priority: 3
SortPriority: 0
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentCaseLabels: true
IndentGotoLabels: true
IndentPPDirectives: None
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Never
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
RawStringFormats:
- Language: Cpp
Delimiters:
- cc
- CC
- cpp
- Cpp
- CPP
- 'c++'
- 'C++'
CanonicalDelimiter: ''
BasedOnStyle: google
- Language: TextProto
Delimiters:
- pb
- PB
- proto
- PROTO
EnclosingFunctions:
- EqualsProto
- EquivToProto
- PARSE_PARTIAL_TEXT_PROTO
- PARSE_TEST_PROTO
- PARSE_TEXT_PROTO
- ParseTextOrDie
- ParseTextProtoOrDie
CanonicalDelimiter: ''
BasedOnStyle: google
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInConditionalStatement: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpaceBeforeSquareBrackets: false
Standard: Auto
StatementMacros:
- Q_UNUSED
- QT_REQUIRE_VERSION
TabWidth: 8
UseCRLF: false
UseTab: Never
...
+12 -6
View File
@@ -158,6 +158,12 @@ typedef sycl::half2 ggml_half2;
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
#ifdef _MSC_VER
#define GGML_EXTENSION
#else // _MSC_VER
#define GGML_EXTENSION __extension__
#endif // _MSC_VER
#define QK4_0 32
typedef struct {
ggml_half d; // delta
@@ -167,7 +173,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b
#define QK4_1 32
typedef struct {
union {
GGML_EXTENSION union {
struct {
ggml_half d; // delta
ggml_half m; // min
@@ -188,7 +194,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0
#define QK5_1 32
typedef struct {
union {
GGML_EXTENSION union {
struct {
ggml_half d; // delta
ggml_half m; // min
@@ -209,7 +215,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block
#define QK8_1 32
typedef struct {
union {
GGML_EXTENSION union {
struct {
ggml_half d; // delta
ggml_half s; // d * sum(qs[i])
@@ -250,7 +256,7 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
union {
GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
@@ -277,7 +283,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
union {
GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
@@ -294,7 +300,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2,
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
typedef struct {
union {
GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
+25 -14
View File
@@ -23,6 +23,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/amx/mmq.cpp
ggml-cpu/amx/mmq.h
ggml-cpu/ggml-cpu-impl.h
ggml-cpu/common.h
ggml-cpu/binary-ops.h
ggml-cpu/binary-ops.cpp
ggml-cpu/unary-ops.h
ggml-cpu/unary-ops.cpp
)
target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17)
@@ -289,23 +294,29 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endif()
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
message(STATUS "PowerPC detected")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
file(READ "/proc/cpuinfo" POWER10_M)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
endif()
if (GGML_NATIVE)
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
file(READ "/proc/cpuinfo" POWER10_M)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
endif()
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
else()
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
endif()
else()
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
if (GGML_CPU_POWERPC_CPUTYPE)
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
message(STATUS "loongarch64 detected")
+158
View File
@@ -0,0 +1,158 @@
#include "binary-ops.h"
#if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>
using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);
#endif
static inline float op_add(float a, float b) {
return a + b;
}
static inline float op_sub(float a, float b) {
return a - b;
}
static inline float op_mul(float a, float b) {
return a * b;
}
static inline float op_div(float a, float b) {
return a / b;
}
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
for (int i = 0; i < n; i++) {
z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));
}
}
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
for (int i = 0; i < n; i++) {
int i10 = i % ne10;
const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);
z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));
}
}
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(dst_t));
GGML_ASSERT(nb00 == sizeof(src0_t));
const auto [ir0, ir1] = get_thread_range(params, src0);
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
GGML_ASSERT(ggml_are_same_shape(src0, src1));
}
#ifdef GGML_USE_ACCELERATE
vDSP_fn_t vDSP_op = nullptr;
// TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (op == op_add) {
vDSP_op = vDSP_vadd;
} else if (op == op_sub) {
vDSP_op = vDSP_vsub;
} else if (op == op_mul) {
vDSP_op = vDSP_vmul;
} else if (op == op_div) {
vDSP_op = vDSP_vdiv;
}
}
#endif
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
if (is_src1_contiguous) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t nr0 = ne00 / ne10;
for (int64_t r = 0; r < nr0; ++r) {
#ifdef GGML_USE_ACCELERATE
if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, float>) {
if (vDSP_op != nullptr) {
vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
continue;
}
}
#endif
vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
}
} else {
vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);
}
}
}
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
template <float (*op)(float, float)>
static void binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
/* */ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
apply_binary_op<op, float, float, float>(params, dst);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
apply_binary_op<op, ggml_fp16_t, ggml_fp16_t, ggml_fp16_t>(params, dst);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
apply_binary_op<op, ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(params, dst);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
apply_binary_op<op, ggml_bf16_t, float, ggml_bf16_t>(params, dst);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
apply_binary_op<op, ggml_bf16_t, float, float>(params, dst);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
apply_binary_op<op, ggml_fp16_t, float, ggml_fp16_t>(params, dst);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
apply_binary_op<op, ggml_fp16_t, float, float>(params, dst);
} else {
GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
}
}
void ggml_compute_forward_add_non_quantized(const ggml_compute_params * params, ggml_tensor * dst) {
binary_op<op_add>(params, dst);
}
void ggml_compute_forward_sub(const ggml_compute_params * params, ggml_tensor * dst) {
binary_op<op_sub>(params, dst);
}
void ggml_compute_forward_mul(const ggml_compute_params * params, ggml_tensor * dst) {
binary_op<op_mul>(params, dst);
}
void ggml_compute_forward_div(const ggml_compute_params * params, ggml_tensor * dst) {
binary_op<op_div>(params, dst);
}
+16
View File
@@ -0,0 +1,16 @@
#pragma once
#include "common.h"
#ifdef __cplusplus
extern "C" {
#endif
void ggml_compute_forward_add_non_quantized(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sub(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_mul(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_div(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}
#endif
+72
View File
@@ -0,0 +1,72 @@
#pragma once
#include "ggml.h"
#include "ggml-cpu-traits.h"
#include "ggml-cpu-impl.h"
#include "ggml-impl.h"
#ifdef __cplusplus
#include <utility>
// convenience functions/macros for use in template calls
// note: these won't be required after the 'traits' lookup table is used.
static inline ggml_fp16_t f32_to_f16(float x) {
return GGML_FP32_TO_FP16(x);
}
static inline float f16_to_f32(ggml_fp16_t x) {
return GGML_FP16_TO_FP32(x);
}
static inline ggml_bf16_t f32_to_bf16(float x) {
return GGML_FP32_TO_BF16(x);
}
static inline float bf16_to_f32(ggml_bf16_t x) {
return GGML_BF16_TO_FP32(x);
}
static inline float f32_to_f32(float x) {
return x;
}
// TODO - merge this into the traits table, after using row-based conversions
template <class T>
struct type_conversion_table;
template <>
struct type_conversion_table<ggml_fp16_t> {
static constexpr float (*to_f32)(ggml_fp16_t) = f16_to_f32;
static constexpr ggml_fp16_t (*from_f32)(float) = f32_to_f16;
};
template <>
struct type_conversion_table<float> {
static constexpr float (*to_f32)(float) = f32_to_f32;
static constexpr float (*from_f32)(float) = f32_to_f32;
};
template <>
struct type_conversion_table<ggml_bf16_t> {
static constexpr float (*to_f32)(ggml_bf16_t) = bf16_to_f32;
static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
};
static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {
const int64_t ith = params->ith;
const int64_t nth = params->nth;
const int64_t nr = ggml_nrows(src0);
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
return {ir0, ir1};
}
#endif
File diff suppressed because it is too large Load Diff
+16 -2
View File
@@ -2680,13 +2680,25 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_0);
vec_t vec_A[4] {0}, vec_B[4] = {0};
for (int l=0; l<k; l+=4) {
if (RN >= 4 && RM == 1) {
/* 'GEMV Forwarding' concept is used in first two conditional loops.
* when one of the matrix has a single row/column, the elements are
* broadcasted, instead of using packing routine to prepack the
* matrix elements.
*/
if (RM == 1) {
TA* a = const_cast<TA*>(A+(ii)*lda+l);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
} else if (RN == 1) {
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
vec_B[0] = (vec_t)vec_xl(0,b);
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
} else {
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
@@ -2790,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
assert(params->ith < params->nth);
// only enable sgemm for prompt processing
#if !defined(__MMA__)
if (n < 2)
return false;
#endif
if (Ctype != GGML_TYPE_F32)
return false;
+186
View File
@@ -0,0 +1,186 @@
#include "unary-ops.h"
static inline float op_abs(float x) {
return fabsf(x);
}
static inline float op_sgn(float x) {
return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
}
static inline float op_neg(float x) {
return -x;
}
static inline float op_step(float x) {
return (x > 0.f) ? 1.f : 0.f;
}
static inline float op_tanh(float x) {
return tanhf(x);
}
static inline float op_elu(float x) {
return (x > 0.f) ? x : expm1f(x);
}
static inline float op_relu(float x) {
return (x > 0.f) ? x : 0.f;
}
static inline float op_sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
static inline float op_hardsigmoid(float x) {
return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}
static inline float op_exp(float x) {
return expf(x);
}
static inline float op_hardswish(float x) {
return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}
static inline float op_sqr(float x) {
return x * x;
}
static inline float op_sqrt(float x) {
return sqrtf(x);
}
static inline float op_sin(float x) {
return sinf(x);
}
static inline float op_cos(float x) {
return cosf(x);
}
static inline float op_log(float x) {
return logf(x);
}
template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
for (int i = 0; i < n; i++) {
y[i] = f32_to_dst(op(src0_to_f32(x[i])));
}
}
template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(dst_t));
GGML_ASSERT(nb00 == sizeof(src0_t));
const auto [ir0, ir1] = get_thread_range(params, src0);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
}
}
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
template <float (*op)(float)>
static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
apply_unary_op<op, float, float>(params, dst);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
apply_unary_op<op, ggml_bf16_t, float>(params, dst);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
apply_unary_op<op, ggml_fp16_t, float>(params, dst);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type));
GGML_ABORT("fatal error");
}
}
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_abs>(params, dst);
}
void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sgn>(params, dst);
}
void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_neg>(params, dst);
}
void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_step>(params, dst);
}
void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_tanh>(params, dst);
}
void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_elu>(params, dst);
}
void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_relu>(params, dst);
}
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sigmoid>(params, dst);
}
void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_hardsigmoid>(params, dst);
}
void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_exp>(params, dst);
}
void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_hardswish>(params, dst);
}
void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sqr>(params, dst);
}
void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sqrt>(params, dst);
}
void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sin>(params, dst);
}
void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_cos>(params, dst);
}
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_log>(params, dst);
}
+28
View File
@@ -0,0 +1,28 @@
#pragma once
#include "common.h"
#ifdef __cplusplus
extern "C" {
#endif
void ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_hardswish(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}
#endif
+4
View File
@@ -288,6 +288,10 @@ static __device__ void no_device_code(
__trap();
GGML_UNUSED(no_device_code); // suppress unused function warning
#if defined(GGML_USE_MUSA)
__builtin_unreachable();
#endif // defined(GGML_USE_MUSA)
}
#ifdef __CUDA_ARCH__
+2 -2
View File
@@ -38,7 +38,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.y < ne01) { // src0
if (blockIdx.y < (unsigned)ne01) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
@@ -64,7 +64,7 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.z < ne02) { // src0
if (blockIdx.z < (unsigned)ne02) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
+4 -2
View File
@@ -34,6 +34,10 @@ static __global__ void conv_transpose_1d_kernel(
}
}
dst[global_index] = accumulator;
GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3);
GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
}
static void conv_transpose_1d_f32_f32_cuda(
@@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor
const int p0 = 0;//opts[3];
const int d0 = 1;//opts[4];
const int64_t kernel_size = ggml_nelements(src0);
const int64_t input_size = ggml_nelements(src1);
const int64_t output_size = ggml_nelements(dst);
conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+1 -1
View File
@@ -577,7 +577,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
return;
}
const src_t * x = (src_t *) vx;
const src_t * x = (const src_t *) vx;
y[i] = x[i];
}
+5 -4
View File
@@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
float vals[sizeof(int)] = {0.0f};
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
vals[l] = scale * x[4*threadIdx.x + l];
}
float amax = fabsf(vals[0]);
float sum = vals[0];
#pragma unroll
for (int l = 1; l < sizeof(int); ++l) {
for (int l = 1; l < int(sizeof(int)); ++l) {
amax = fmaxf(amax, fabsf(vals[l]));
sum += vals[l];
}
@@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
if (d != 0.0f) {
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
q8[l] = roundf(vals[l] / d);
}
}
@@ -638,7 +638,7 @@ static __global__ void flash_attn_combine_results(
float VKQ_denominator = 0.0f;
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
const float KQ_max_scale = expf(diff);
float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
@@ -649,6 +649,7 @@ static __global__ void flash_attn_combine_results(
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
}
[[noreturn]]
static void on_no_fattn_vec_case(const int D) {
if (D == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
+55 -30
View File
@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#endif // CP_ASYNC_AVAILABLE
#else
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
#else
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}
@@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
// Kernels with ncols == 128 are only 4% faster due to register pressure.
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
+13 -1
View File
@@ -282,7 +282,19 @@ static __global__ void flash_attn_tile_ext_f16(
}
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
+12
View File
@@ -281,6 +281,18 @@ static __global__ void flash_attn_tile_ext_f32(
}
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
+13 -1
View File
@@ -292,7 +292,19 @@ static __global__ void flash_attn_vec_ext_f16(
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
+10
View File
@@ -277,6 +277,16 @@ static __global__ void flash_attn_vec_ext_f32(
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
+11 -1
View File
@@ -430,7 +430,17 @@ static __global__ void flash_attn_ext_f16(
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}
+7
View File
@@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[3] != 1) {
return false;
}
+2
View File
@@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
return ret;
@@ -178,6 +179,7 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#endif // NEW_MMA_AVAILABLE
}
+38 -22
View File
@@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
}
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
@@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (k01 < WARP_SIZE/2) {
constexpr int ns = 2;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
} else {
constexpr int ns = 1;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
constexpr int ns = 2;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
}
// Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
// As a workaround 2 separate loops are used instead.
#pragma unroll
for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
constexpr int ns = 1;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
}
@@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
}
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@@ -1253,7 +1268,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const float d = bxi->d;
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
}
#else
@@ -1376,7 +1391,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
}
}
@@ -1517,7 +1532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
}
}
@@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
}
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile(
} else {
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
}
GGML_UNUSED(ne00); GGML_UNUSED(ne10);
}
@@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
if (it != blockIdx.x || jt != blockIdx.y) {
if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
continue;
}
@@ -2825,7 +2842,6 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo;
+1 -1
View File
@@ -29,7 +29,7 @@ static __global__ void mul_mat_vec(
__syncthreads();
}
float sumf;
float sumf = 0.0f;
if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
+4 -2
View File
@@ -151,7 +151,7 @@ static __global__ void mul_mat_vec_q(
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -197,10 +197,12 @@ static __global__ void mul_mat_vec_q(
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
}
}
GGML_UNUSED(nrows_x);
}
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+1 -1
View File
@@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
int offset_src =
nidx +
blockIdx.y * ne00 +
+1 -1
View File
@@ -19,7 +19,7 @@ static __global__ void upscale_f32(const float * x, float * dst,
int i02 = i12 / sf2;
int i03 = i13 / sf3;
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
}
static void upscale_f32_cuda(const float * x, float * dst,
+6 -3
View File
@@ -219,9 +219,12 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
File diff suppressed because it is too large Load Diff
+288 -227
View File
@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src + il));
reg = (type4)(*(src));
}
#if defined(GGML_METAL_USE_BF16)
@@ -56,6 +56,11 @@ template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#endif
template <typename type4x4>
@@ -3100,7 +3105,8 @@ template<
typename vd4x4_t, // key type in device memory
short nl_v,
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short D, // head size
short DK, // K head size
short DV, // V head size
short Q = 8, // queries per threadgroup
short KV = 8, // key/value processed per each simdgroup
short C = 32> // cache items per threadgroup
@@ -3122,20 +3128,24 @@ kernel void kernel_flash_attn_ext(
const int iq2 = tgpig[1];
const int iq1 = tgpig[0]*Q;
const short D4 = D/4;
const short D8 = D/8;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
constexpr short DK4 = DK/4;
constexpr short DK8 = DK/8;
constexpr short DK16 = DK/16;
constexpr short DV4 = DV/4;
constexpr short DV8 = DV/8;
constexpr short DV16 = DV/16;
constexpr short NW = N_SIMDWIDTH;
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
const short T = D + 2*TS; // shared memory size per query in (half)
const short T = DK + 2*TS; // shared memory size per query in (half)
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3144,23 +3154,23 @@ kernel void kernel_flash_attn_ext(
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o8x8_t lo[D8];
o8x8_t lo[DV8];
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) {
for (short i = tiisg; i < DK4; i += NW) {
if (iq1 + j < args.ne01) {
sq4[j*D4 + i] = (q4_t) q4[i];
sq4[j*DK4 + i] = (q4_t) q4[i];
} else {
sq4[j*D4 + i] = (q4_t) 0.0f;
sq4[j*DK4 + i] = (q4_t) 0.0f;
}
}
}
// zero out lo
for (short i = 0; i < D8; ++i) {
for (short i = 0; i < DV8; ++i) {
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
}
@@ -3190,13 +3200,6 @@ kernel void kernel_flash_attn_ext(
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q8x8_t mq[D8];
for (short i = 0; i < D8; ++i) {
simdgroup_load(mq[i], sq + i*8, D);
}
const bool has_mask = mask != q;
half slope = 1.0f;
@@ -3249,20 +3252,22 @@ kernel void kernel_flash_attn_ext(
// this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
#pragma unroll(DK8)
for (short i = 0; i < DK8; ++i) {
k8x8_t mk;
simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
q8x8_t mq;
simdgroup_load(mq, sq + i*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
for (short ii = 0; ii < DK16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
if (D16%4 == 0) {
if (DK16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
{
k4x4_t tmp;
@@ -3275,15 +3280,18 @@ kernel void kernel_flash_attn_ext(
#pragma unroll(4)
for (short k = 0; k < 4; ++k) {
k8x8_t mk;
q8x8_t mq;
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
}
} else {
if (ii + tx < D16) {
if (ii + tx < DK16) {
k4x4_t tmp;
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
sk4x4[4*ty + tx] = tmp;
@@ -3291,14 +3299,17 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
for (short k = 0; k < 4 && ii + k < DK16; ++k) {
k8x8_t mk;
q8x8_t mq;
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
}
}
}
@@ -3350,8 +3361,8 @@ kernel void kernel_flash_attn_ext(
s8x8_t mm;
simdgroup_load(mm, ss + 2*C, TS, 0, false);
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]);
}
}
@@ -3364,20 +3375,20 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
v8x8_t mv;
simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
for (short ii = 0; ii < DV16; ii += 4) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
if (D16%4 == 0) {
if (DV16%4 == 0) {
// no need for bound checks
{
v4x4_t tmp;
@@ -3398,7 +3409,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + tx < D16) {
if (ii + tx < DV16) {
v4x4_t tmp;
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
sv4x4[4*ty + tx] = tmp;
@@ -3406,7 +3417,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
for (short k = 0; k < 4 && ii + k < DV16; ++k) {
v8x8_t mv;
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@@ -3440,8 +3451,8 @@ kernel void kernel_flash_attn_ext(
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], so + i*8, D, 0, false);
for (short i = 0; i < DV8; ++i) {
simdgroup_store(lo[i], so + i*8, DV, 0, false);
}
}
@@ -3480,11 +3491,11 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
o8x8_t t;
simdgroup_load (t, so + i*8, D, 0, false);
simdgroup_load (t, so + i*8, DV, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3495,8 +3506,8 @@ kernel void kernel_flash_attn_ext(
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (short i = 0; i < D8; ++i) {
simdgroup_store(lo[i], so + i*8, D, 0, false);
for (short i = 0; i < DV8; ++i) {
simdgroup_store(lo[i], so + i*8, DV, 0, false);
}
}
@@ -3507,8 +3518,8 @@ kernel void kernel_flash_attn_ext(
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
for (short i = tiisg; i < DV4; i += NW) {
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
}
}
}
@@ -3525,80 +3536,94 @@ kernel void kernel_flash_attn_ext(
float, simdgroup_float8x8, \
half, half4, simdgroup_half8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
#endif
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
#undef FA_TYPES
template<
typename q4_t, // query types in shared memory
typename q4x4_t,
typename k4x4_t, // key types in shared memory
typename v4x4_t, // value types in shared memory
typename qk_t, // Q*K types
typename s_t, // soft-max types
typename q4_t, // query types in shared memory
typename k4_t, // key types in shared memory
typename v4_t, // value types in shared memory
typename qk_t, // Q*K types
typename s_t, // soft-max types
typename s4_t,
typename s4x4_t,
typename o4x4_t, // attention accumulation types
typename kd4x4_t, // key type in device memory
typename o4_t, // attention accumulation types
typename kd4_t, // key type in device memory
short nl_k,
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
typename vd4x4_t, // key type in device memory
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
typename vd4_t, // key type in device memory
short nl_v,
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short D, // head size
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q,
@@ -3617,29 +3642,28 @@ kernel void kernel_flash_attn_ext_vec(
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
const short SH = 2*C; // shared memory per simdgroup
constexpr short DK4 = DK/4;
constexpr short DV4 = DV/4;
constexpr short NW = N_SIMDWIDTH;
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
constexpr short SH = 2*C; // shared memory per simdgroup
const short T = D + nsg*SH; // shared memory size per query in (half)
const short T = DK + nsg*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o4x4_t lo[D16/NL];
// store the result for all queries in local memory (the O matrix from the paper)
o4_t lo[DV4/NL];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) {
for (short i = tiisg; i < DK4; i += NW) {
if (iq1 < args.ne01) {
sq4[i] = (q4_t) q4[i];
} else {
@@ -3648,8 +3672,8 @@ kernel void kernel_flash_attn_ext_vec(
}
// zero out lo
for (short i = 0; i < D16/NL; ++i) {
lo[i] = (o4x4_t) 0.0f;
for (short i = 0; i < DV4/NL; ++i) {
lo[i] = (o4_t) 0.0f;
}
// zero out shared memory SH
@@ -3674,14 +3698,6 @@ kernel void kernel_flash_attn_ext_vec(
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q4x4_t mq[D16/NL];
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
mq[ii/NL] = sq4x4[ii + tx];
}
const bool has_mask = mask != q;
// pointer to the mask
@@ -3713,43 +3729,56 @@ kernel void kernel_flash_attn_ext_vec(
// Q*K^T
{
// each simdgroup processes 1 query and 4 (NW/NL) keys
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
// each simdgroup processes 1 query and NE (NW/NL) head elements
for (short cc = 0; cc < C/NE; ++cc) {
qk_t mqk = 0.0f;
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
#pragma unroll(DK4/NL)
for (short ii = 0; ii < DK4; ii += NL) {
const short i = ii + tx;
k4x4_t mk;
deq_k(pk + i/nl_k, i%nl_k, mk);
k4_t mk;
deq_k_t4(pk + i/nl_k, i%nl_k, mk);
// note: this is less precise than the version below
//mqka[0] += dot(mq[ii/NL][0], mk[0]);
//mqka[1] += dot(mq[ii/NL][1], mk[1]);
//mqka[2] += dot(mq[ii/NL][2], mk[2]);
//mqka[3] += dot(mq[ii/NL][3], mk[3]);
//mqka[0] += dot(mq[0], mk[0]);
//mqka[1] += dot(mq[1], mk[1]);
//mqka[2] += dot(mq[2], mk[2]);
//mqka[3] += dot(mq[3], mk[3]);
mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
//q4x4_t mq = sq4x4[i];
//mqka[0] += dot((float4) mq[0], (float4) mk[0]);
//mqka[1] += dot((float4) mq[1], (float4) mk[1]);
//mqka[2] += dot((float4) mq[2], (float4) mk[2]);
//mqka[3] += dot((float4) mq[3], (float4) mk[3]);
mqk += dot((float4) mk, (float4) sq4[i]);
}
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
// simdgroup reduce
// simdgroup reduce (NE = 4)
// [ 0 .. 7] -> [ 0]
// [ 8 .. 15] -> [ 8]
// [16 .. 23] -> [16]
// [24 .. 31] -> [24]
//mqk += simd_shuffle_down(mqk, 16);
//mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
if (NE <= 1) {
mqk += simd_shuffle_down(mqk, 16);
}
if (NE <= 2) {
mqk += simd_shuffle_down(mqk, 8);
}
if (NE <= 4) {
mqk += simd_shuffle_down(mqk, 4);
}
if (NE <= 8) {
mqk += simd_shuffle_down(mqk, 2);
}
if (NE <= 16) {
mqk += simd_shuffle_down(mqk, 1);
}
// mqk = mqk*scale + mask*slope
if (tx == 0) {
@@ -3759,9 +3788,9 @@ kernel void kernel_flash_attn_ext_vec(
mqk = args.logit_softcap*precise::tanh(mqk);
}
mqk += sm[4*cc + ty]*slope;
mqk += sm[NE*cc + ty]*slope;
ss[4*cc + ty] = mqk;
ss[NE*cc + ty] = mqk;
}
}
}
@@ -3784,8 +3813,8 @@ kernel void kernel_flash_attn_ext_vec(
ss[tiisg] = vs;
// O = diag(ms)*O
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
#pragma unroll(DV4/NL)
for (short ii = 0; ii < DV4; ii += NL) {
lo[ii/NL] *= ms;
}
}
@@ -3794,17 +3823,18 @@ kernel void kernel_flash_attn_ext_vec(
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/4; ++cc) {
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
//#pragma unroll(C/NE)
for (short cc = 0; cc < C/NE; ++cc) {
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
const s4x4_t ms(ss[4*cc + ty]);
const s4_t ms(ss[NE*cc + ty]);
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
#pragma unroll(DV4/NL)
for (short ii = 0; ii < DV4; ii += NL) {
const short i = ii + tx;
v4x4_t mv;
deq_v(pv4 + i/nl_v, i%nl_v, mv);
v4_t mv;
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
lo[ii/NL] += mv*ms;
}
@@ -3819,7 +3849,7 @@ kernel void kernel_flash_attn_ext_vec(
}
}
// simdgroup reduce
// simdgroup reduce (NE = 4)
// [ 0, 8, 16, 24] -> [ 0]
// [ 1, 9, 17, 25] -> [ 1]
// [ 2, 10, 18, 26] -> [ 2]
@@ -3828,37 +3858,48 @@ kernel void kernel_flash_attn_ext_vec(
// [ 5, 13, 21, 29] -> [ 5]
// [ 6, 14, 22, 30] -> [ 6]
// [ 7, 15, 23, 31] -> [ 7]
for (short ii = 0; ii < D16; ii += NL) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
for (short ii = 0; ii < DV4; ii += NL) {
if (NE > 1) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
}
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
if (NE > 2) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
}
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
if (NE > 4) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
}
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
if (NE > 8) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
}
if (NE > 16) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// store results to shared memory
for (short i = tiisg; i < D16; i += NL) {
sr4x4[i] = lo[i/NL];
for (short i = tiisg; i < DV4; i += NL) {
sr4[i] = lo[i/NL];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3885,22 +3926,22 @@ kernel void kernel_flash_attn_ext_vec(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short i = tiisg; i < D16; i += NW) {
sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
for (short i = tiisg; i < DV4; i += NW) {
sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device float4x4 * dst44 = (device float4x4 *) dst;
device float4 * dst4 = (device float4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) {
dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
for (short i = tiisg; i < DV4; i += NW) {
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
}
}
}
@@ -3909,34 +3950,54 @@ kernel void kernel_flash_attn_ext_vec(
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//
#define FA_TYPES \
half4, half4x4, \
half4x4, \
half4x4, \
float, \
half, half4, half4x4, \
half4x4
half4, \
half4, \
half4, \
float, \
half, half4, \
half4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
#undef FA_TYPES
+1
View File
@@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS
ggml-opencl_transpose_16
ggml-opencl_transpose_32
ggml-opencl_transpose_32_16
ggml-opencl_im2col
)
foreach (K ${GGML_OPENCL_KERNELS})
+238 -14
View File
@@ -224,12 +224,14 @@ struct ggml_backend_opencl_context {
cl_program program;
cl_program program_1;
cl_program program_2;
cl_program program_im2col;
cl_kernel kernel_add, kernel_add_row;
cl_kernel kernel_mul, kernel_mul_row;
cl_kernel kernel_scale;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
cl_kernel kernel_relu;
cl_kernel kernel_clamp;
cl_kernel kernel_norm;
@@ -239,6 +241,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
cl_kernel kernel_mul_mat_f32_f32;
cl_kernel kernel_mul_mat_f16_f16;
@@ -252,6 +255,7 @@ struct ggml_backend_opencl_context {
kernel_mul_mat_q4_0_f32_flat_img_v0;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
// Transpose kernels
@@ -708,6 +712,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err));
CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err));
CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err));
CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err));
CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err));
@@ -722,6 +728,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err));
CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err));
CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err));
CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err));
CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err));
CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err));
CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err));
CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err));
@@ -769,6 +779,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err));
// im2col kernels
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src_im2col {
#include "ggml-opencl_im2col.cl.h"
};
#else
const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl");
#endif
backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err));
CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err));
// Kernels for Adreno
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1187,6 +1210,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU_QUICK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
@@ -1216,14 +1240,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->ne[3] == 1;
case GGML_OP_ROPE: {
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope && !is_vision) {
if (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16) {
return true;
}
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
if (is_vision) {
if (op->src[0]->type == GGML_TYPE_F32 ||
op->src[0]->type == GGML_TYPE_F16) {
return true;
}
return false;
}
return true;
}
case GGML_OP_IM2COL:
return true;
default:
return false;
}
@@ -2582,6 +2618,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
#endif
}
static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
cl_command_queue queue = backend_ctx->queue;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
int n = ggml_nelements(dst);
if (n % 4 == 0) {
kernel = backend_ctx->kernel_gelu_quick_4;
n /= 4;
} else {
kernel = backend_ctx->kernel_gelu_quick;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt);
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
#else
clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL);
#endif
}
static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -3980,6 +4063,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
float attn_factor;
float beta_fast;
float beta_slow;
int32_t sections[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -3987,23 +4071,23 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
const bool is_neox = mode & 2;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
cl_kernel kernel;
if (!is_neox) {
switch (src0->type) {
case GGML_TYPE_F32:
kernel = backend_ctx->kernel_rope_norm_f32;
break;
case GGML_TYPE_F16:
kernel = backend_ctx->kernel_rope_norm_f16;
break;
default:
GGML_ASSERT(false);
};
} else {
if (is_neox) {
switch (src0->type) {
case GGML_TYPE_F32:
kernel = backend_ctx->kernel_rope_neox_f32;
@@ -4014,6 +4098,39 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
default:
GGML_ASSERT(false);
};
} else if (is_mrope && !is_vision) {
switch (src0->type) {
case GGML_TYPE_F32:
kernel = backend_ctx->kernel_rope_multi_f32;
break;
case GGML_TYPE_F16:
kernel = backend_ctx->kernel_rope_multi_f16;
break;
default:
GGML_ASSERT(false);
};
} else if (is_vision) {
switch (src0->type) {
case GGML_TYPE_F32:
kernel = backend_ctx->kernel_rope_vision_f32;
break;
case GGML_TYPE_F16:
kernel = backend_ctx->kernel_rope_vision_f16;
break;
default:
GGML_ASSERT(false);
}
} else {
switch (src0->type) {
case GGML_TYPE_F32:
kernel = backend_ctx->kernel_rope_norm_f32;
break;
case GGML_TYPE_F16:
kernel = backend_ctx->kernel_rope_norm_f16;
break;
default:
GGML_ASSERT(false);
};
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
@@ -4049,6 +4166,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
if (is_mrope || is_vision) {
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
}
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -4064,6 +4184,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
#endif
}
static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
// src0 - filter, src1 - input
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
cl_command_queue queue = backend_ctx->queue;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const cl_long IC = src1->ne[is_2D ? 2 : 1];
const cl_long IH = is_2D ? src1->ne[1] : 1;
const cl_long IW = src1->ne[0];
const cl_long KH = is_2D ? src0->ne[1] : 1;
const cl_long KW = src0->ne[0];
const cl_long OH = is_2D ? dst->ne[2] : 1;
const cl_long OW = dst->ne[1];
// nb is byte offset, src is type float32
const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4;
const cl_long batch = src1->ne[is_2D ? 3 : 2];
const cl_ulong batch_offset = src1->nb[is_2D ? 3 : 2]/4;
const cl_long pelements = OW*KW*KH;
const cl_long CHW = IC*KH*KW;
cl_kernel kernel;
if(dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_im2col_f16;
} else {
kernel = backend_ctx->kernel_im2col_f32;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &batch_offset));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &delta_offset));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &IW));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &IH));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &IC));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &OW));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_long), &OH));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_long), &KW));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_long), &KH));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_long), &pelements));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_long), &CHW));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &s0));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &s1));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &p0));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &p1));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &d0));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &d1));
const int num_blocks = (pelements + 256 - 1) / 256;
size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC};
size_t local_work_size[] = {256, 1, 1};
#ifdef GGML_OPENCL_PROFILING
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
#else
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
#endif
}
//------------------------------------------------------------------------------
// Op offloading
//------------------------------------------------------------------------------
@@ -4122,6 +4334,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_gelu;
break;
case GGML_UNARY_OP_GELU_QUICK:
if (!any_on_device) {
return false;
}
func = ggml_cl_gelu_quick;
break;
case GGML_UNARY_OP_SILU:
if (!any_on_device) {
return false;
@@ -4194,6 +4412,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_rope;
break;
case GGML_OP_IM2COL:
if (!any_on_device) {
return false;
}
func = ggml_cl_im2col;
break;
default:
return false;
}
+389
View File
@@ -404,6 +404,7 @@ kernel void kernel_scale(
// gelu
//------------------------------------------------------------------------------
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
kernel void kernel_gelu(
@@ -434,6 +435,32 @@ kernel void kernel_gelu_4(
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_quick(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
//------------------------------------------------------------------------------
// silu
//------------------------------------------------------------------------------
@@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16(
}
}
kernel void kernel_rope_multi_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_multi_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_vision_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
int ic = i0/2;
const int sector = (i0/2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
const int p = sector;
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
} else if (sector >= sections.s0 && sector < sec_w) {
const int p = sector - sections.s0;
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
}
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
}
}
kernel void kernel_rope_vision_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
int ic = i0/2;
const int sector = (i0/2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
const int p = sector;
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
} else if (sector >= sections.s0 && sector < sec_w) {
const int p = sector - sections.s0;
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
}
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
}
}
//------------------------------------------------------------------------------
// cpy
//------------------------------------------------------------------------------
@@ -0,0 +1,146 @@
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#elif defined(cl_amd_fp16)
#pragma OPENCL EXTENSION cl_amd_fp16 : enable
#else
#error "Half precision floating point not supportedby OpenCL implementation on your device."
#endif
#ifdef cl_khr_subgroups
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#elif defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#error "Subgroup not supported on your device."
#endif
#ifdef cl_intel_required_subgroup_size
// Always use subgroup size of 32 on Intel.
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
// Always use subgroups size of 64 on Adreno.
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
// TODO: do not know how to choose subgroup size on other GPUs.
#error "Selecting subgroup size is not supported on your device."
#endif
kernel void kernel_im2col_f32(
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
ulong batch_offset,
ulong delta_offset,
long IW,
long IH,
long IC,
long OW,
long OH,
long KW,
long KH,
long pelements,
long CHW,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1
) {
// threadIdx.x + blockIdx.x * blockDim.x
long i = get_global_id(0);
if (i >= pelements) {
return;
}
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
long ksize = OW * (KH > 1 ? KW : 1);
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
long ix = i % OW;
long oh = get_group_id(1);
long batch = get_group_id(2) / IC;
long ic = get_group_id(2) % IC;
long iiw = ix * s0 + kx * d0 - p0;
long iih = oh * s1 + ky * d1 - p1;
long offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
long offset_src = ic * delta_offset + batch * batch_offset;
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
}
}
kernel void kernel_im2col_f16(
global float * src1,
ulong offset1,
global half * dst,
ulong offsetd,
ulong batch_offset,
ulong delta_offset,
long IW,
long IH,
long IC,
long OW,
long OH,
long KW,
long KH,
long pelements,
long CHW,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1
) {
long i = get_global_id(0);
if (i >= pelements) {
return;
}
src1 = (global float*)((global char*)src1 + offset1);
dst = (global half*)((global char*)dst + offsetd);
long ksize = OW * (KH > 1 ? KW : 1);
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
long ix = i % OW;
long oh = get_group_id(1);
long batch = get_group_id(2) / IC;
long ic = get_group_id(2) % IC;
long iiw = ix * s0 + kx * d0 - p0;
long iih = oh * s1 + ky * d1 - p1;
long offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
long offset_src = ic * delta_offset + batch * batch_offset;
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
}
}
+143 -7
View File
@@ -26,6 +26,10 @@
# include <unistd.h>
#endif
#include <cstring>
#include <fstream>
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
typedef SOCKET sockfd_t;
@@ -80,6 +84,7 @@ enum rpc_cmd {
RPC_CMD_FREE_BUFFER,
RPC_CMD_BUFFER_CLEAR,
RPC_CMD_SET_TENSOR,
RPC_CMD_SET_TENSOR_HASH,
RPC_CMD_GET_TENSOR,
RPC_CMD_COPY_TENSOR,
RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
RPC_CMD_COUNT,
};
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
struct rpc_msg_get_alloc_size_req {
rpc_tensor tensor;
};
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
uint8_t value;
};
struct rpc_msg_set_tensor_hash_rsp {
uint8_t result;
};
struct rpc_msg_get_tensor_req {
rpc_tensor tensor;
uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
// RPC helper functions
// Computes FNV-1a hash of the data
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
const uint64_t fnv_prime = 0x100000001b3ULL;
uint64_t hash = 0xcbf29ce484222325ULL;
for (size_t i = 0; i < len; ++i) {
hash ^= data[i];
hash *= fnv_prime;
}
return hash;
}
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
rpc_tensor rpc_tensor = serialize_tensor(tensor);
if (size > HASH_THRESHOLD) {
// input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
std::vector<uint8_t> input(input_size, 0);
uint64_t hash = fnv_hash((const uint8_t*)data, size);
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
rpc_msg_set_tensor_hash_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
GGML_ASSERT(status);
if (response.result) {
// the server has the same data, no need to send it
return;
}
}
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
std::vector<uint8_t> input(input_size, 0);
rpc_tensor rpc_tensor = serialize_tensor(tensor);
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
@@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
class rpc_server {
public:
rpc_server(ggml_backend_t backend) : backend(backend) {}
rpc_server(ggml_backend_t backend, const char * cache_dir)
: backend(backend), cache_dir(cache_dir) {
}
~rpc_server();
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,6 +824,7 @@ public:
bool free_buffer(const rpc_msg_free_buffer_req & request);
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
bool set_tensor(const std::vector<uint8_t> & input);
bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -789,6 +832,7 @@ public:
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
ggml_tensor * create_node(uint64_t id,
struct ggml_context * ctx,
@@ -797,6 +841,7 @@ private:
ggml_backend_t backend;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
};
@@ -960,11 +1005,85 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
}
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
if (cache_dir && size > HASH_THRESHOLD) {
uint64_t hash = fnv_hash((const uint8_t*)data, size);
char hash_str[17];
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
// save to cache_dir/hash_str
fs::path cache_file = fs::path(cache_dir) / hash_str;
std::ofstream ofs(cache_file, std::ios::binary);
ofs.write((const char *)data, size);
printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
}
ggml_backend_tensor_set(tensor, data, offset, size);
ggml_free(ctx);
return true;
}
bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
if (!cache_dir) {
return false;
}
char hash_str[17];
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
fs::path cache_file = fs::path(cache_dir) / hash_str;
if (!fs::exists(cache_file)) {
return false;
}
std::ifstream ifs(cache_file, std::ios::binary);
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
data.resize(size);
ifs.read((char *)data.data(), size);
return true;
}
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
{
// serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
if (input.size() != sizeof(rpc_tensor) + 16) {
return false;
}
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
uint64_t offset;
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
std::vector<uint8_t> cached_file;
if (!get_cached_file(*hash, cached_file)) {
response.result = 0;
return true;
}
size_t size = cached_file.size();
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
struct ggml_context * ctx = ggml_init(params);
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
if (tensor == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
ggml_free(ctx);
return false;
}
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
// sanitize tensor->data
{
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
}
}
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
response.result = 1;
ggml_free(ctx);
return true;
}
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1267,9 @@ rpc_server::~rpc_server() {
}
}
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
rpc_server server(backend);
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
rpc_server server(backend, cache_dir);
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1260,6 +1380,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
}
break;
}
case RPC_CMD_SET_TENSOR_HASH: {
std::vector<uint8_t> input;
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_set_tensor_hash_rsp response;
if (!server.set_tensor_hash(input, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
case RPC_CMD_INIT_TENSOR: {
rpc_msg_init_tensor_req request;
if (!recv_msg(sockfd, &request,sizeof(request))) {
@@ -1335,7 +1469,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
}
}
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
const char * cache_dir,
size_t free_mem, size_t total_mem) {
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
@@ -1364,7 +1500,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
fflush(stdout);
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
fflush(stdout);
}
-35
View File
@@ -66,41 +66,6 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
return sycl_down_blk_size;
}
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const ggml_sycl_op_flatten_t op) try {
const bool use_src1 = src1 != nullptr;
if(use_src1)
GGML_ASSERT(strcmp(src1->buffer->buft->iface.get_name(src1->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
// dd = data device
float * src0_ddf = (float *) src0->data;
float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
float * dst_ddf = (float *) dst->data;
ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
ggml_sycl_set_device(ctx.device);
queue_ptr main_stream = ctx.stream();
// GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
// ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
// do the computation
op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
// print_ggml_tensor("tensor", dst);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
+8 -20
View File
@@ -494,12 +494,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream);
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
@@ -757,24 +751,22 @@ struct bin_bcast_sycl {
template <class op>
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
const ggml_tensor *src1, ggml_tensor *dst) {
dpct::queue_ptr main_stream = ctx.stream();
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
(sycl::half *)dst_dd, main_stream);
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data,
(sycl::half *)dst->data, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data,
main_stream);
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
main_stream);
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
main_stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
@@ -784,8 +776,4 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
}
bool gpu_has_xmx(sycl::device &dev);
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const ggml_sycl_op_flatten_t op);
#endif // GGML_SYCL_COMMON_HPP
+185 -273
View File
@@ -509,497 +509,409 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00,
});
}
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float negative_slope;
memcpy(&negative_slope, dst->op_params, sizeof(float));
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
}
inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
}
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const float sf0 = (float)dst->ne[0]/src0->ne[0];
const float sf1 = (float)dst->ne[1]/src0->ne[1];
const float sf2 = (float)dst->ne[2]/src0->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3];
const float sf0 = (float)dst->ne[0]/dst->src[0]->ne[0];
const float sf1 = (float)dst->ne[1]/dst->src[0]->ne[1];
const float sf2 = (float)dst->ne[2]/dst->src[0]->ne[2];
const float sf3 = (float)dst->ne[3]/dst->src[0]->ne[3];
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
pad_f32_sycl(src0_dd, dst_dd,
src0->ne[0], src0->ne[1], src0->ne[2],
dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
float * dst_dd = static_cast<float *>(dst->data);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
GGML_UNUSED(dst);
GGML_UNUSED(ctx);
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
}
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
}
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt);
ggml_sycl_op_sqrt(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin);
ggml_sycl_op_sin(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos);
ggml_sycl_op_cos(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc);
ggml_sycl_op_acc(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu);
ggml_sycl_op_gelu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu);
ggml_sycl_op_silu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick);
ggml_sycl_op_gelu_quick(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh);
ggml_sycl_op_tanh(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu);
ggml_sycl_op_relu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid);
ggml_sycl_op_sigmoid(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid);
ggml_sycl_op_hardsigmoid(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish);
ggml_sycl_op_hardswish(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp);
ggml_sycl_op_exp(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log);
ggml_sycl_op_log(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg);
ggml_sycl_op_neg(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step);
ggml_sycl_op_step(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu);
ggml_sycl_op_leaky_relu(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr);
ggml_sycl_op_sqr(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale);
ggml_sycl_op_upscale(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad);
ggml_sycl_op_pad(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
@@ -1007,24 +919,24 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add);
ggml_sycl_op_add(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub);
ggml_sycl_op_sub(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul);
ggml_sycl_op_mul(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div);
ggml_sycl_op_div(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
+24 -20
View File
@@ -257,50 +257,54 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
GGML_UNUSED(ctx);
}
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_d, const float *src1_d,
float *dst_d, const queue_ptr &stream) {
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) {
const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data;
/* TODO: Refactor and remove duplicates */
switch (dst->src[0]->type) {
case GGML_TYPE_F16:
get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
src1_i32, dst_d, stream);
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_F32:
get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q4_0:
if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
} else {
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
}
break;
case GGML_TYPE_Q4_1:
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q5_0:
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q5_1:
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q8_0:
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
default:
// TODO: k-quants
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type));
GGML_ABORT("fatal error");
}
}
+1 -4
View File
@@ -15,9 +15,6 @@
#include "common.hpp"
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_d, const float *src1_d,
float *dst_d, const queue_ptr &stream);
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
#endif // GGML_SYCL_GETROWS_HPP
+85 -127
View File
@@ -1988,16 +1988,8 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_d, const float *src1_d,
float *dst_d,
const queue_ptr &main_stream) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(src1_d);
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
}
@@ -2132,13 +2124,14 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd, const queue_ptr &main_stream) {
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@@ -2149,8 +2142,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t IH = dst->src[0]->ne[1];
const int64_t IW = dst->src[0]->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
@@ -2169,163 +2162,125 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
parallel_elements, src0_dd, dst_dd, op,
item_ct1);
});
GGML_UNUSED(src1);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ne = ggml_nelements(src0);
const int64_t ne = ggml_nelements(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
}
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
}
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int nrows0 = ggml_nrows(src0);
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t ne01 = dst->src[0]->ne[1];
const int nrows0 = ggml_nrows(dst->src[0]);
const int n_past = ((int32_t *) dst->op_params)[0];
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
/*
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code.
*/
SYCL_CHECK(0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float min;
float max;
memcpy(&min, dst->op_params, sizeof(float));
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream());
/*
DPCT1010:88: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code.
*/
SYCL_CHECK(0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
@@ -2695,37 +2650,37 @@ catch (sycl::exception const &exc) {
static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
ggml_sycl_op_repeat(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
ggml_sycl_op_get_rows(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
ggml_sycl_op_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
ggml_sycl_op_rms_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
ggml_sycl_op_l2_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
ggml_sycl_op_group_norm(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
@@ -3269,48 +3224,48 @@ catch (sycl::exception const &exc) {
}
static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
ggml_sycl_op_scale(ctx, dst);
}
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
ggml_sycl_op_clamp(ctx, dst);
}
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
ggml_sycl_op_diag_mask_inf(ctx, dst);
}
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
ggml_sycl_op_rope(ctx, dst);
}
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
ggml_sycl_op_pool2d(ctx, dst);
}
static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
ggml_sycl_op_im2col(ctx, dst);
}
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
ggml_sycl_op_sum(ctx, dst);
}
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
ggml_sycl_op_sum_rows(ctx, dst);
}
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
ggml_sycl_op_argsort(ctx, dst);
}
static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
ggml_sycl_op_argmax(ctx, dst);
}
@@ -3335,7 +3290,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
if (!g_sycl_loaded) return false;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@@ -3528,6 +3483,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
}
return true;
} catch (sycl::exception & e) {
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::exit(1);
}
GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
+5 -10
View File
@@ -82,10 +82,9 @@ static void im2col_sycl(
}
}
void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -115,12 +114,8 @@ void ggml_sycl_op_im2col(
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
if (dst->type == GGML_TYPE_F16) {
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
} else {
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
}
GGML_UNUSED(src0);
GGML_UNUSED(src0_dd);
GGML_UNUSED(ctx);
}
+1 -3
View File
@@ -16,8 +16,6 @@
#include "common.hpp"
void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream);
ggml_backend_sycl_context & ctx, ggml_tensor *dst);
#endif // GGML_SYCL_IM2COL_HPP
+35 -47
View File
@@ -397,90 +397,78 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
}
}
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd,
const queue_ptr& main_stream) {
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
}
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0];
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
GGML_UNUSED(ctx);
int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
}
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
}
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream) {
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
(void)src1;
(void)dst;
(void)src1_dd;
}
+4 -19
View File
@@ -15,27 +15,12 @@
#include "common.hpp"
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
ggml_tensor* dst, const float* src0_dd,
const float* src1_dd, float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst,
const float* src0_dd, const float* src1_dd,
float* dst_dd,
const queue_ptr& main_stream);
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
#endif // GGML_SYCL_NORM_HPP
+20 -25
View File
@@ -192,18 +192,15 @@ static void rope_neox_sycl(
}
}
void ggml_sycl_op_rope(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
const ggml_tensor * src2 = dst->src[2];
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(dst->src[0]->type == dst->type);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t nr = ggml_nrows(src0);
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t ne01 = dst->src[0]->ne[1];
const int64_t nr = ggml_nrows(dst->src[0]);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -228,49 +225,47 @@ void ggml_sycl_op_rope(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const int32_t * pos = (const int32_t *) src1_dd;
const int32_t * pos = (const int32_t *) dst->src[1]->data;
const float * freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *) src2->data;
if (dst->src[2] != nullptr) {
freq_factors = (const float *) dst->src[2]->data;
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_neox_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_neox_sycl(
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32) {
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_norm_sycl(
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_norm_sycl(
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, main_stream
);
} else {
GGML_ABORT("fatal error");
}
}
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_dd);
GGML_UNUSED(ctx);
}
+1 -3
View File
@@ -15,8 +15,6 @@
#include "common.hpp"
void ggml_sycl_op_rope(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
#endif // GGML_SYCL_ROPE_HPP
+32 -22
View File
@@ -23,32 +23,40 @@ if (Vulkan_FOUND)
../../include/ggml-vulkan.h
)
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if(NOT DEFINED GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
else()
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat is supported by glslc")
else()
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat is supported by glslc")
endif()
endif()
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if(NOT DEFINED GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
else()
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat2 is supported by glslc")
else()
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat2 is supported by glslc")
endif()
endif()
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
@@ -119,6 +127,8 @@ if (Vulkan_FOUND)
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
+4
View File
@@ -8764,6 +8764,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
default:
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}
@@ -1,5 +1,11 @@
find_package (Threads REQUIRED)
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
endif()
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
+1 -1
View File
@@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
}
// permute(0, 2, 1, 3)
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale, max_bias, logit_softcap };
+40
View File
@@ -286,6 +286,8 @@ class MODEL_ARCH(IntEnum):
GRANITE_MOE = auto()
CHAMELEON = auto()
WAVTOKENIZER_DEC = auto()
PLM = auto()
BAILINGMOE = auto()
class MODEL_TENSOR(IntEnum):
@@ -488,6 +490,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1464,6 +1468,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.PLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN,
],
MODEL_ARCH.CHATGLM : [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ROPE_FREQS,
@@ -1651,6 +1669,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.POSNET_ATTN_V,
MODEL_TENSOR.POSNET_ATTN_OUT,
],
MODEL_ARCH.BAILINGMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO
}
@@ -1703,6 +1740,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.BAILINGMOE: [
MODEL_TENSOR.ROPE_FREQS,
],
}
#
+1
View File
@@ -29,6 +29,7 @@ class TensorNameMap:
"shared", # t5
"rwkv.embeddings", # rwkv6
"model.embeddings", # rwkv7
"model.word_embeddings", # bailingmoe
),
# Token type embeddings
+6
View File
@@ -108,6 +108,8 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
};
enum llama_rope_type {
@@ -1265,6 +1267,10 @@ extern "C" {
float tau,
float eta);
/// @details Intializes a GBNF grammar, see grammars/README.md for details.
/// @param vocab The vocabulary that this grammar will be used with.
/// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
/// @param grammar_root The name of the start symbol for the grammar.
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
+34
View File
@@ -0,0 +1,34 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg id="Layer_1" xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1500 500">
<!-- Generator: Adobe Illustrator 29.3.1, SVG Export Plug-In . SVG Version: 2.1.0 Build 151) -->
<defs>
<style>
.st0 {
fill: #ff8236;
}
.st1 {
fill: #fff;
}
.st2 {
fill: #1b1f20;
}
</style>
</defs>
<rect class="st2" width="1500" height="500" rx="16" ry="16"/>
<g>
<path class="st1" d="M749.4,353.8l5.4-204.1,20.4-.8,45.1,98.8,42.5-99h19l6.5,205h-38l-2-98-24.9,61.4c-1,1.3-8,1.3-9-1l-25.6-61.4-1.5,99h-38Z"/>
<path class="st1" d="M727.5,240.1c-10.8-27.1-53.1-24.5-75.3-14.7l3.1,28.4c9.2-1.9,30-8,37.5-1,.9.9,3.5,5.7,3.5,6.5v16.5c-31.8-17.2-54.5,6.1-54.4,38.5,0,36.5,28.4,57.3,56.4,27.5v12h32v-104.5c0-.5-2.4-8-2.8-9.2ZM696.4,327.8c-8.4,1.7-15.4,2.9-19.2-6.3-5.8-14,.6-37.9,19.2-27.2v33.5Z"/>
<path class="st1" d="M899.4,353.8l47.6-205.1h30.3c0,.1,47,205.1,47,205.1h-38l-7.9-33.6h-34.1l-7.9,33.6h-37ZM951.4,285.8h20l-10.5-56-9.5,56Z"/>
<polygon class="st1" points="490.4 148.8 490.4 317.3 491.9 318.8 534.4 318.8 534.4 353.8 451.4 353.8 451.4 150.3 452.9 148.8 490.4 148.8"/>
<polygon class="st1" points="589.4 148.8 589.4 318.8 633.4 318.8 633.4 353.8 550.4 353.8 550.4 148.8 589.4 148.8"/>
<g>
<path class="st0" d="M1163.3,226.8l-13.5,24c-17.8-13.7-44.2-15.7-62-1-28.7,23.7-26.7,78.5,18,78.8,12.5,0,23.1-5.9,34.5-9.8l6,23.9c-10.1,4.7-20.4,9.5-31.5,11-101.2,13.8-95.4-132.3-3.9-139.9,19.2-1.6,36.1,3.4,52.5,13Z"/>
<path class="st0" d="M1093.4,203.8c-15.4,4.6-29.7,13.1-40.5,25-2-24.2,3.4-73.1,30.3-82.7,4-1.4,17.7-4.9,17.3,2.2s-9.9,19.3-12.2,25.9c-4,11.6-.3,19.6,5.2,29.7Z"/>
<polygon class="st0" points="1131.4 258.8 1131.4 276.8 1147.4 276.8 1147.4 290.8 1131.4 290.8 1131.4 307.8 1116.4 307.8 1116.4 290.8 1099.4 290.8 1099.4 276.8 1114.9 276.8 1116.4 275.3 1116.4 258.8 1131.4 258.8"/>
<polygon class="st0" points="1186.4 258.8 1186.4 275.3 1187.9 276.8 1203.4 276.8 1203.4 290.8 1186.4 290.8 1186.4 307.8 1171.4 307.8 1171.4 290.8 1155.4 290.8 1155.4 276.8 1171.4 276.8 1171.4 258.8 1186.4 258.8"/>
<path class="st0" d="M1142.3,156.9c2,3-9.3,15.9-11.1,19.2-5.2,9.8-1.7,15.4,2.2,24.7-11.3-1.7-21.8-.3-33,1,2.5-21.5,14.6-52.8,41.9-44.9Z"/>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.3 KiB

+16 -3
View File
@@ -69,7 +69,11 @@ while read c; do
git format-patch -U${ctx} -k $c~1..$c --stdout -- \
CMakeLists.txt \
src/CMakeLists.txt \
cmake/FindSIMD.cmake \
cmake/BuildTypes.cmake \
cmake/GitVars.cmake \
cmake/common.cmake \
cmake/ggml-config.cmake.in \
src/ggml-cpu/cmake/FindSIMD.cmake \
src/ggml*.h \
src/ggml*.c \
src/ggml*.cpp \
@@ -121,7 +125,12 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
#
# CMakelists.txt -> ggml/CMakeLists.txt
# src/CMakeLists.txt -> ggml/src/CMakeLists.txt
# cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake
# cmake/BuildTypes.cmake -> ggml/cmake/BuildTypes.cmake
# cmake/GitVars.cmake -> ggml/cmake/GitVars.cmake
# cmake/common.cmake -> ggml/cmake/common.cmake
# cmake/ggml-config.cmake.in -> ggml/cmake/ggml-config.cmake.in
# src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake
#
# src/ggml*.c -> ggml/src/ggml*.c
# src/ggml*.cpp -> ggml/src/ggml*.cpp
@@ -151,7 +160,11 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
cat ggml-src.patch | sed -E \
-e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
-e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \
-e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
-e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
+1 -1
View File
@@ -1 +1 @@
c7dfe3d174f98b14801f9ed12f129179d3e7b638
d53795ee70aa545464569d71caa46f38c05c1982
+3 -1
View File
@@ -2,7 +2,9 @@
cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt
cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt
cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake
cp -rpv ../ggml/cmake/* ./ggml/cmake/
cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/
cp -rpv ../ggml/src/ggml*.c ./ggml/src/
cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/
+41
View File
@@ -65,6 +65,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -1043,6 +1045,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_PLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_CHATGLM,
{
@@ -1392,6 +1410,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
{
LLM_ARCH_BAILINGMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_UNKNOWN,
{
+2
View File
@@ -69,6 +69,8 @@ enum llm_arch {
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_UNKNOWN,
};
+41 -1
View File
@@ -59,6 +59,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -168,6 +170,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains(" Ассистент:")) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -567,6 +573,41 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
ss << " Пользователь: " << chat[i]->content << "\n\n";
} else if (role == "assistant") {
ss << " Ассистент: " << chat[i]->content << "\n\n";
}
}
// Add generation prompt if needed
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content;
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else {
// template not supported
return -1;
@@ -585,4 +626,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
}
return (int32_t) LLM_CHAT_TEMPLATES.size();
}
+2
View File
@@ -38,6 +38,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
+2 -7
View File
@@ -1317,8 +1317,8 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
// find KV slot
{
kv_self_update();
// if we have enough unused cells before the current head ->
@@ -2316,11 +2316,6 @@ llama_context * llama_init_from_model(
params.flash_attn = false;
}
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;
}
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
+70 -104
View File
@@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask || self_kq_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) {
const int64_t n_kv = kv_self->n;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_kv = kv_self->n;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
float * data = nullptr;
float * data_swa = nullptr;
if (self_kq_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
data = (float *) self_kq_mask->data;
}
if (self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
data_swa = (float *) self_kq_mask_swa->data;
}
// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
float f;
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(kv_self->cells[i].pos - pos);
} else {
f = 0.0f;
}
}
if (data) {
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
}
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
}
} else {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_stride = n_tokens;
float * data = nullptr;
float * data_swa = nullptr;
if (self_kq_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
data = (float *) self_kq_mask->data;
}
float * data = (float *) self_kq_mask->data;
if (self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
data_swa = (float *) self_kq_mask_swa->data;
}
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
// Causal mask:
// xxx-------
// xxxx------
// xxxxx-----
// Non-causal mask:
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (int h = 0; h < 1; ++h) {
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
float f;
// mask the token if:
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(kv_self->cells[i].pos - pos);
} else {
f = 0.0f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
if (data) {
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
}
// mask padded tokens
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
// mask padded tokens
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
+450 -32
View File
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_1_4B: return "1.4B";
case LLM_TYPE_1_5B: return "1.5B";
case LLM_TYPE_1_6B: return "1.6B";
case LLM_TYPE_1_8B: return "1.8B";
case LLM_TYPE_2B: return "2B";
case LLM_TYPE_2_8B: return "2.8B";
case LLM_TYPE_2_9B: return "2.9B";
@@ -87,6 +88,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B";
case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_27B: return "27B";
case LLM_TYPE_290B: return "290B";
default: return "?B";
}
}
@@ -255,7 +257,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
return nullptr;
}
// CPU: ACCEL -> CPU extra -> GPU host -> CPU
// CPU: ACCEL -> GPU host -> CPU extra -> CPU
static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices) {
buft_list_t buft_list;
@@ -271,32 +273,6 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
}
}
bool has_gpu_device = false;
for (auto * dev : devices) {
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
has_gpu_device = true;
break;
}
}
// add extra buffer types, only if no GPU device is present
// ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
if (!has_gpu_device) {
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
if (ggml_backend_dev_get_extra_bufts_fn) {
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
while (extra_bufts && *extra_bufts) {
buft_list.emplace_back(cpu_dev, *extra_bufts);
++extra_bufts;
}
}
} else {
LLAMA_LOG_WARN("%s: disabling extra buffer types (i.e. repacking) since a GPU device is available\n", __func__);
}
// add a host buffer type
// storing the tensors in a host buffer is useful when the processing of large batches
// is offloaded to a GPU device, since it reduces the time spent on data transfers
@@ -311,6 +287,20 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
}
}
// add extra buffer types, only if no GPU device is present
// ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
if (ggml_backend_dev_get_extra_bufts_fn) {
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
while (extra_bufts && *extra_bufts) {
buft_list.emplace_back(cpu_dev, *extra_bufts);
++extra_bufts;
}
}
// add the CPU buffer type
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
@@ -1144,6 +1134,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_PLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_1_8B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_CHATGLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1330,6 +1329,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break;
case LLM_ARCH_BAILINGMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
switch (hparams.n_layer) {
case 28: type = LLM_TYPE_16B; break;
case 88: type = LLM_TYPE_290B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -3068,6 +3082,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_PLM:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t kv_lora_rank = hparams.n_lora_kv;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
// output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_BITNET:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3712,6 +3755,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break;
case LLM_ARCH_BAILINGMOE:
{
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_shared = hparams.n_expert_shared;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
if (n_expert == 0) {
throw std::runtime_error("n_expert must be > 0");
}
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0");
}
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -3999,6 +4082,14 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
}
if (arch == LLM_ARCH_BAILINGMOE) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}
vocab.print_info();
}
@@ -6284,7 +6375,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(cur, "ffn_moe_out", il);
cb(moe_out, "ffn_moe_out", il);
// FFN shared expert
{
@@ -11615,6 +11706,322 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
}
};
struct llm_build_plm : public llm_graph_context {
llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
ggml_tensor * cur;
ggml_tensor * inpL;
// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self_attention
{
ggml_tensor * q = NULL;
q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(q, "q", il);
// split into {n_head * n_embd_head_qk_nope, n_tokens}
ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
0);
cb(q_nope, "q_nope", il);
// and {n_head * n_embd_head_qk_rope, n_tokens}
ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
ggml_row_size(q->type, n_embd_head_qk_nope));
cb(q_pe, "q_pe", il);
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
cb(kv_pe_compresseed, "kv_pe_compresseed", il);
// split into {kv_lora_rank, n_tokens}
ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
kv_pe_compresseed->nb[1],
0);
cb(kv_compressed, "kv_compressed", il);
// and {n_embd_head_qk_rope, n_tokens}
ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_pe_compresseed->nb[1],
kv_pe_compresseed->nb[1],
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
cb(k_pe, "k_pe", il);
kv_compressed = build_norm(kv_compressed,
model.layers[il].attn_kv_a_norm, NULL,
LLM_NORM_RMS, il);
cb(kv_compressed, "kv_compressed", il);
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
cb(kv, "kv", il);
// split into {n_head * n_embd_head_qk_nope, n_tokens}
ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
0);
cb(k_nope, "k_nope", il);
// and {n_head * n_embd_head_v, n_tokens}
ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
cb(v_states, "v_states", il);
v_states = ggml_cont(ctx0, v_states);
cb(v_states, "v_states", il);
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
0);
cb(v_states, "v_states", il);
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);
// shared RoPE key
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);
ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(q_states, "q_states", il);
ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(k_states, "k_states", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
q_states, k_states, v_states, nullptr, kq_scale, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_bailingmoe : public llm_graph_context {
llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
false, hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
// FFN shared expert
{
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
llama_memory_i * llama_model::create_memory() const {
llama_memory_i * res;
@@ -11846,10 +12253,11 @@ llm_graph_result_ptr llama_model::build_graph(
GGML_ABORT("invalid graph type");
};
} break;
//case LLM_ARCH_T5ENCODER:
// {
// llm.build_t5_enc(gf);
// } break;
case LLM_ARCH_T5ENCODER:
{
llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
}
break;
case LLM_ARCH_JAIS:
{
llm = std::make_unique<llm_build_jais>(*this, params, gf);
@@ -11886,6 +12294,14 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
} break;
case LLM_ARCH_PLM:
{
llm = std::make_unique<llm_build_plm>(*this, params, gf);
} break;
case LLM_ARCH_BAILINGMOE:
{
llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@@ -12012,10 +12428,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_BAILINGMOE:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
+2
View File
@@ -44,6 +44,7 @@ enum llm_type {
LLM_TYPE_1_4B,
LLM_TYPE_1_5B,
LLM_TYPE_1_6B,
LLM_TYPE_1_8B,
LLM_TYPE_2B,
LLM_TYPE_2_8B,
LLM_TYPE_2_9B,
@@ -84,6 +85,7 @@ enum llm_type {
LLM_TYPE_10B_128x3_66B,
LLM_TYPE_57B_A14B,
LLM_TYPE_27B,
LLM_TYPE_290B,
};
struct llama_layer_posnet {
+5
View File
@@ -1477,6 +1477,7 @@ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sam
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
GGML_ASSERT(result);
// copy the state
{
@@ -1548,6 +1549,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
};
if (!ctx->grammar) {
delete ctx;
return nullptr;
}
} else {
*ctx = {
/* .vocab = */ vocab,
+16
View File
@@ -342,6 +342,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
case LLAMA_VOCAB_PRE_TYPE_MPT:
case LLAMA_VOCAB_PRE_TYPE_OLMO:
case LLAMA_VOCAB_PRE_TYPE_JAIS:
case LLAMA_VOCAB_PRE_TYPE_TRILLION:
regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
@@ -406,6 +407,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?=(\\d{3})+(?!\\d))",
};
break;
case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
regex_exprs = {
// original regex from tokenizer.json
// "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1614,6 +1622,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "superbpe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
clean_spaces = false;
} else if (
tokenizer_pre == "trillion") {
pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
clean_spaces = false;
} else if (
tokenizer_pre == "bailingmoe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
+38 -31
View File
@@ -3217,7 +3217,8 @@ struct test_leaky_relu : public test_case {
// GGML_OP_FLASH_ATTN_EXT
struct test_flash_attn_ext : public test_case {
const int64_t hs; // head size
const int64_t hsk; // K head size
const int64_t hsv; // V head size
const int64_t nh; // num heads
const int64_t nr; // repeat in Q, tests for grouped-query attention
const int64_t kv; // kv size
@@ -3233,7 +3234,7 @@ struct test_flash_attn_ext : public test_case {
std::array<int32_t, 4> permute;
std::string vars() override {
return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
}
double max_nmse_err() override {
@@ -3243,17 +3244,18 @@ struct test_flash_attn_ext : public test_case {
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
// Just counting matmul costs:
// Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
return 2 * 2 * nh*nr * nb * hs * kv;
// Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
return 2 * nh*nr * nb * (hsk + hsv) * kv;
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
: hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
int64_t ne[4] = {ne0, ne1, ne2, ne3};
@@ -3268,13 +3270,13 @@ struct test_flash_attn_ext : public test_case {
return t;
};
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
ggml_set_name(q, "q");
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
ggml_set_name(k, "k");
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
@@ -3283,7 +3285,7 @@ struct test_flash_attn_ext : public test_case {
ggml_set_name(m, "m");
}
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
ggml_flash_attn_ext_set_prec(out, prec);
ggml_set_name(out, "out");
@@ -4412,27 +4414,32 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
for (int hs : { 64, 80, 128, 256, }) {
for (bool mask : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
if (!mask && max_bias > 0.0f) continue;
for (float logit_softcap : {0.0f, 10.0f}) {
if (hs != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 4, }) {
for (int nr : { 1, 4, 16 }) {
if (nr == 16 && hs != 128) continue;
for (int kv : { 512, 1024, }) {
if (nr != 1 && kv != 512) continue;
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
for (int hsk : { 64, 80, 128, 192, 256, }) {
for (int hsv : { 64, 80, 128, 192, 256, }) {
if (hsk != 192 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
for (bool mask : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
if (!mask && max_bias > 0.0f) continue;
for (float logit_softcap : {0.0f, 10.0f}) {
if (hsk != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 4, }) {
for (int nr : { 1, 4, 16 }) {
if (nr == 16 && hsk != 128) continue;
for (int kv : { 512, 1024, }) {
if (nr != 1 && kv != 512) continue;
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
}
}
}
}
+8
View File
@@ -270,6 +270,14 @@ int main(void) {
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
/* .template_str= */ "<s>{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n {%- set name = tool.function.name %}\n {%- set description = tool.function.description|default('') %}\n {%- set parameters = tool.function.parameters|tojson %}\n {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n {{- tools_prefix }}\n {%- for tool in tools %}\n {{- __render_tool(tool) }}\n {%- endfor %}\n {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n {{- names.assistant }}\n {%- set call = message['function_call'] %}\n {%- if call %}\n {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n {%- else %}\n {{- ' ' + message.content + '\\n\\n' }}\n {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {{- __render_user_message(message) }}\n {%- endif %}\n {%- if message.role == 'assistant' and not loop.last %}\n {{- __render_assistant_message(message) }}\n {%- endif %}\n {%- if message.role == 'tool' %}\n {{- __render_tool_message(message) }}\n {%- endif %}\n {%- if loop.last %}\n {{- ' Ассистент:[SEP]' }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "<s> Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
/* .expected_output_jinja= */ "<s> Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
};
std::vector<char> formatted_chat(1024);
int32_t res;