Compare commits

..

36 Commits

Author SHA1 Message Date
Jeff Bolz 44e18ef939 vulkan: fix coopmat2 flash attention for non-contiguous inputs (#11281)
Add code similar to mul_mm_cm2 to force alignment of strides, to avoid
a performance regression.

Add noncontiguous FA tests in test-backend-ops.

Fixes #11268.
2025-01-18 09:26:50 +01:00
codezjx 3edfa7d375 llama.android: add field formatChat to control whether to parse special tokens when send message (#11270) 2025-01-17 14:57:56 +02:00
Radoslav Gerganov 667d72846c rpc : early register backend devices (#11262)
Early register RPC devices and do not propagate RPC specifics in the
llama model structures.

ref: #10609
2025-01-17 10:57:09 +02:00
Georgi Gerganov a133566d34 vocab : fix double-eos check (#11273)
ggml-ci
2025-01-17 09:28:00 +02:00
David Renshaw 960ec65273 llama : fix deprecation message: vocabable -> vocab (#11269) 2025-01-17 08:12:01 +01:00
musoles 7a689c415e README : added kalavai to infrastructure list (#11216) 2025-01-17 01:10:49 +01:00
Jeff Bolz bd38ddea01 vulkan: support copy from f32 to q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl (#11166)
* vulkan: support copy from f32 to q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl

Shaders are based on cpy.cu.

* vulkan: support copy from q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl to f32

* ggml: copy q->f32 assumes some contiguity in the destination
2025-01-16 22:47:10 +01:00
Jeff Bolz 466300fe14 vulkan: optimize coopmat2 q4_k/q5_k dequant functions. (#11206)
Do masking on whole dwords, fetch all scales at once.
2025-01-16 22:23:49 +01:00
Jeff Bolz 206bc53422 vulkan: optimize coopmat2 q2_k dequant function (#11130) 2025-01-16 22:16:39 +01:00
RunningLeon 4dbc8b9cb7 llama : add internlm3 support (#11233)
* support internlm3

* fix lint
2025-01-16 20:10:38 +02:00
Johannes Gäßler 9c8dcefe17 CUDA: backwards pass for misc. ops, add tests (#11257)
* CUDA: backwards pass for misc. ops, add tests

* remove restrict from pointers
2025-01-16 16:43:38 +01:00
Xuan Son Nguyen 681149ced2 llama : add llama_model_load_from_splits (#11255)
* llama : add `llama_model_load_from_splits`

* update
2025-01-16 13:54:08 +01:00
fj-y-saito c67cc9837d ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot (#11227)
* Add SVE support for q4_K_q8_K

* Update ggml/src/ggml-cpu/ggml-cpu-quants.c

change to use K_SCALE_SIZE

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-01-16 11:11:49 +02:00
Eve adc5dd92e8 vulkan: scale caching for k quants + misc fixes (#11081)
* q6_k scale caching

* 16 bit unpack

* q4_k test (slow)

* revert it

* q3_k

* q2_k

* little stuff

* try precalculating products of a and q2_k scales

* Revert "try precalculating products of a and q2_k scales"

This reverts commit 65110b81f23f66331a50c6e889a7c1ab9470a86b.

* unpack should be u16, add vim swap to gitignore (about time)

* better q4_k scales

* q5_k

* better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations

* q2_k better dequant

* q3_k optimizations

* q3_k use hmask simd from cpu avx version

* make the caches happy

* q3_k separate out calculation

* q2_k separate out

* little stuff

* use calc_superblock everywhere

* q2_k optimize scale calculation

* more barriers
2025-01-15 19:50:13 +00:00
Georgi Gerganov f11cfdfd7f ci : use -no-cnv in gguf-split tests (#11254)
* ci : use -no-cnv in gguf-split tests

ggml-ci

* ci : use -no-cnv in requantize tests

ggml-ci

* scripts : fix [no ci]
2025-01-15 18:28:35 +02:00
Junil Kim 1d8504338e fix: ggml: fix vulkan-shaders-gen build (#10448)
* fix: ggml: fix vulkan-shaders-gen build

The vulkan-shaders-gen target was not being built correctly
in case of cross-compilation.
Other outputs need to be built for the cross compile target,
but vulkan-shaders-gen needs to be built for the host.

* refactor: ggml: Improve vulkan-shaders-gen toolchain setup

- Add GGML_SHADERS_GEN_TOOLCHAIN CMake option.
- Auto-detect host toolchain if not set.

* refactor: ggml: Improve vulkan-shaders-gen toolchain setup

Use configure_file to generate host_toolchain.cmake from template

* fix: ggml: Fix compile error

Fix compile error not finding vulkan-shaders-gen

* fix: vulkan-shaders-gen build and path handling

Fix build issues with vulkan-shaders-gen:
- Add target dependency for correct build order
- Use CMAKE_HOST_SYSTEM_NAME for executable suffix
- Fix MSVC output directory in host toolchain
- Normalize path handling for cross-compilation

* fix: improve host compiler detection in vulkan shader build

Improve host compiler detection for vulkan shader generation:
- Add NO_CMAKE_FIND_ROOT_PATH to all compiler searches
- Consolidate compiler detection logic
- Fix Windows-specific MSVC detection
- Ensure correct compiler search in cross-compilation

* refactor: Simplify CMake function for detecting host compiler

Simplified the CMake function to improve the process of detecting the host compiler.

* fix: Remove unnecessary Vulkan library linkage in CMakeLists.txt

Since `vulkan-shader-gen.cpp` only requires the `glslc` executable
and not the Vulkan headers or libraries, CMakeLists.txt needs to
be corrected.
(See: ecc93d0558)

* refactor: Rename host_toolchain.cmake.in

- Rename host_toolchain.cmake.in to cmake/host-toolchain.cmake.in

* refactor: GGML_VULKAN_SHADERS_GEN_TOOLCHAIN

Rename the macro GGML_SHADERS_GEN_TOOLCHAIN to GGML_VULKAN_SHADERS_GEN_TOOLCHAIN
2025-01-15 14:17:42 +01:00
Johannes Gäßler 432df2d5f9 RoPE: fix back, CUDA support for back + noncont. (#11240)
* RoPE: fix back, CUDA support for back + noncont.

* fix comments reg. non-cont. RoPE support [no-ci]
2025-01-15 12:51:37 +01:00
Daniel Bevenius 0ccd7f3eb2 examples : add embd_to_audio to tts-outetts.py [no ci] (#11235)
This commit contains a suggestion for adding the missing embd_to_audio
function from tts.cpp to tts-outetts.py. This introduces a depencency
numpy which I was not sure if that is acceptable or not (only PyTorch
was mentioned in referened PR).

Also the README has been updated with instructions to run the example
with llama-server and the python script.

Refs: https://github.com/ggerganov/llama.cpp/pull/10784#issuecomment-2548377734
2025-01-15 05:44:38 +01:00
Akarshan Biswas f446c2cf6a SYCL: Add gated linear attention kernel (#11175)
* SYCL: Add Gated Linear attention kernel

* glahpp: add a space at the end of file

* gla: Put the barrier inside the main logic loop
2025-01-15 11:20:17 +08:00
Xuan Son Nguyen b4d92a59a2 ci : add -no-cnv for tests (#11238) 2025-01-14 16:42:23 +02:00
Georgi Gerganov bbf3e55e35 vocab : add dummy tokens for "no_vocab" type (#11231)
* vocab : add dummy tokens for "no_vocab" type

ggml-ci

* vocab : minor [no ci]
2025-01-14 11:54:58 +01:00
ebraminio c5bf0d1bd7 server : Improve code snippets direction between RTL text (#11221) 2025-01-14 11:39:33 +01:00
Olivier Chafik 091592d758 Refactor test-chat-template.cpp (#11224)
* Refactor test-chat-template

* Update test-chat-template.cpp
2025-01-14 10:16:41 +00:00
Georgi Gerganov 44d1e796d0 sync : ggml 2025-01-14 10:39:42 +02:00
Georgi Gerganov a4f3f5d8e6 scripts : sync gguf (cont) 2025-01-14 09:40:52 +02:00
Georgi Gerganov 48e1ae0e61 scripts : sync gguf 2025-01-14 09:36:58 +02:00
Georgi Gerganov d00a80e89d scripts : sync opencl 2025-01-14 09:19:58 +02:00
ebraminio 504af20ee4 server : (UI) Improve messages bubble shape in RTL (#11220)
I simply have overlooked message bubble's tail placement for RTL
text as I use the dark mode and that isn't visible there and this
fixes it.
2025-01-13 20:23:31 +01:00
Xuan Son Nguyen 84a44815f7 cli : auto activate conversation mode if chat template is available (#11214)
* cli : auto activate conversation mode if chat template is detected

* add warn on bad template

* update readme (writing with the help of chatgpt)

* update readme (2)

* do not activate -cnv for non-instruct models
2025-01-13 20:18:12 +01:00
Andreas Kieslinger 39509fb082 cuda : CUDA Graph Compute Function Refactor (precursor for performance improvements) (#11042)
* Refactor: Moves cuda graph executable update step to separate function.

* Refactor: Moves cuda graph update check to separate function.

* Refactor: Moves cuda graph maintenance (update or adjusting copy parameters) to separate function for improved readability.

* Fix: Adds missing reference to maintain_cuda_graph() definition.

* Refactor: Improves structure and abstractions by moving CUDA graph evaluation and capture to its own function.

* Refactor: Moves node graph checks and copy ops into individual function for improved readability.

* Refactor: Removes code permanently excluded from compilation to increase readability.

* Style: Adds missing newline

* Style: Consolidates several neighboring '#ifdef USE_CUDA_GRAPH' into a single one

* Refactor: Makes 'cuda_graph_update_required' a local variable

* remove double lines between functions

---------

Co-authored-by: slaren <slarengh@gmail.com>
2025-01-13 16:45:53 +01:00
Georgi Gerganov a29f0870d4 contrib : add naming guidelines (cont) (#11177) 2025-01-13 15:59:26 +02:00
ebraminio 437e05f714 server : (UI) Support for RTL text as models input or output (#11208) 2025-01-13 14:46:39 +01:00
Georgi Gerganov ca001f6656 contrib : add naming guidelines (cont) (#11177) 2025-01-13 15:08:44 +02:00
Xuan Son Nguyen 00b4c3da62 common : support tag-based --hf-repo like on ollama (#11195)
* common : support tag-based hf_repo like on ollama

* fix build

* various fixes

* small fixes

* fix style

* fix windows build?

* move common_get_hf_file to common.cpp

* fix complain with noreturn
2025-01-13 13:56:23 +01:00
Georgi Gerganov 7426a26b24 contrib : add naming guidelines (#11177)
* contrib : add naming guidelines

* contrib : expand naming guidelines [no ci]

* contrib : cont [no ci]

* contrib : add `_t` suffix guideline [no ci]

* contrib : cont [no ci]

* minor [no ci]

* contrib : move coding guidelines to correct section [no ci]

* contrib : minor reword coding guidelines [no ci]

* contrib : add TODO for preprocessor directives [no ci]

* contrib : expand [no ci]

* minor [no ci]

* contrib : clarify `_context` suffix usage [no ci]

* contrib : filename guidelines [no ci]

* contrib : fix notes [no ci]
2025-01-13 14:46:36 +02:00
Daniel Bevenius 8f70fc3d1b llama : remove 'd' from bad special token log (#11212)
This commit removes the 'd' from the log message in llama-vocab.cpp
when logging a bad special token.

The motivation for this is that currently the output can look something
like the following:
```console
load: bad special token:
    'tokenizer.ggml.image_token_id' = 128256d, using default id -1
```
2025-01-13 13:38:20 +01:00
76 changed files with 3672 additions and 1660 deletions
+1
View File
@@ -18,6 +18,7 @@
*.metallib
*.o
*.so
*.swp
*.tmp
# IDE / OS
+96 -6
View File
@@ -1,10 +1,10 @@
# Pull requests (for contributors)
- Test your changes:
- Execute [the full CI locally on your machine](ci/README.md) before publishing
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)
- If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends)
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
- Execute [the full CI locally on your machine](ci/README.md) before publishing
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)
- If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends)
- If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops`
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
- If your PR becomes stale, don't hesitate to ping the maintainers in the comments
@@ -20,14 +20,104 @@
- Avoid adding third-party dependencies, extra files, extra headers, etc.
- Always consider cross-compatibility with other operating systems and architectures
- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple
- There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit
- Vertical alignment makes things more readable and easier to batch edit
- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a`
- Naming usually optimizes for common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963)
- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets
- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo`
- In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary
```cpp
// OK
llama_context * ctx;
const llama_rope_type rope_type;
// not OK
struct llama_context * ctx;
const enum llama_rope_type rope_type;
```
_(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline.)_
- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` to format the added code
- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines)
- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices
- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$
![matmul](media/matmul.png)
# Naming guidelines
- Use `snake_case` for function, variable and type names
- Naming usually optimizes for longest common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963)
```cpp
// not OK
int small_number;
int big_number;
// OK
int number_small;
int number_big;
```
- Enum values are always in upper case and prefixed with the enum name
```cpp
enum llama_vocab_type {
LLAMA_VOCAB_TYPE_NONE = 0,
LLAMA_VOCAB_TYPE_SPM = 1,
LLAMA_VOCAB_TYPE_BPE = 2,
LLAMA_VOCAB_TYPE_WPM = 3,
LLAMA_VOCAB_TYPE_UGM = 4,
LLAMA_VOCAB_TYPE_RWKV = 5,
};
```
- The general naming pattern is `<class>_<method>`, with `<method>` being `<action>_<noun>`
```cpp
llama_model_init(); // class: "llama_model", method: "init"
llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove"
llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed"
llama_set_embeddings(); // class: "llama_context", method: "set_embeddings"
llama_n_threads(); // class: "llama_context", method: "n_threads"
llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free"
```
- The `get` `<action>` can be omitted
- The `<noun>` can be omitted if not necessary
- The `_context` suffix of the `<class>` is optional. Use it to disambiguate symbols when needed
- Use `init`/`free` for constructor/destructor `<action>`
- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else
```cpp
typedef struct llama_context * llama_context_t;
enum llama_pooling_type llama_pooling_type(const llama_context_t ctx);
```
_(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline)_
- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension
- Python filenames are all lowercase with underscores
- _(TODO: abbreviations usage)_
# Preprocessor directives
- _(TODO: add guidelines with examples and apply them to the codebase)_
```cpp
#ifdef FOO
#endif // FOO
```
# Documentation
- Documentation is a community effort
- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference
- When you notice incorrect or outdated documentation, please update it
# Resources
The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects:
+22 -17
View File
@@ -204,6 +204,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs
- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly
- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server
- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale
</details>
@@ -245,6 +246,8 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt
- [Trending](https://huggingface.co/models?library=gguf&sort=trending)
- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf)
You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from Hugging Face by using this CLI argument: `-hf <user>/<model>[:quant]`
After downloading a model, use the CLI tools to run it locally - see below.
`llama.cpp` requires the model to be stored in the [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) file format. Models in other data formats can be converted to GGUF using the `convert_*.py` Python scripts in this repo.
@@ -263,21 +266,12 @@ To learn more about model quantization, [read this documentation](examples/quant
#### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality.
- <details open>
<summary>Run simple text completion</summary>
```bash
llama-cli -m model.gguf -p "I believe the meaning of life is" -n 128
# I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey.
```
</details>
- <details>
<summary>Run in conversation mode</summary>
Models with a built-in chat template will automatically activate conversation mode. If this doesn't occur, you can manually enable it by adding `-cnv` and specifying a suitable chat template with `--chat-template NAME`
```bash
llama-cli -m model.gguf -p "You are a helpful assistant" -cnv
llama-cli -m model.gguf
# > hi, who are you?
# Hi there! I'm your helpful assistant! I'm an AI-powered chatbot designed to assist and provide information to users like you. I'm here to help answer your questions, provide guidance, and offer support on a wide range of topics. I'm a friendly and knowledgeable AI, and I'm always happy to help with anything you need. What's on your mind, and how can I assist you today?
@@ -289,17 +283,28 @@ To learn more about model quantization, [read this documentation](examples/quant
</details>
- <details>
<summary>Run with custom chat template</summary>
<summary>Run in conversation mode with custom chat template</summary>
```bash
# use the "chatml" template
llama-cli -m model.gguf -p "You are a helpful assistant" -cnv --chat-template chatml
# use the "chatml" template (use -h to see the list of supported templates)
llama-cli -m model.gguf -cnv --chat-template chatml
# use a custom template
llama-cli -m model.gguf -p "You are a helpful assistant" -cnv --in-prefix 'User: ' --reverse-prompt 'User:'
llama-cli -m model.gguf -cnv --in-prefix 'User: ' --reverse-prompt 'User:'
```
[Supported templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
</details>
- <details>
<summary>Run simple text completion</summary>
To disable conversation mode explicitly, use `-no-cnv`
```bash
llama-cli -m model.gguf -p "I believe the meaning of life is" -n 128 -no-cnv
# I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey.
```
</details>
+33 -33
View File
@@ -326,17 +326,17 @@ function gg_run_open_llama_7b_v2 {
./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k
./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k
(time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
@@ -460,17 +460,17 @@ function gg_run_pythia_1_4b {
./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k
./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k
(time ./bin/llama-cli --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-cli -no-cnv --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli -no-cnv --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli -no-cnv --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
@@ -591,17 +591,17 @@ function gg_run_pythia_2_8b {
./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k
./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k
(time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log
(time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log
(time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log
(time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
+60 -20
View File
@@ -130,17 +130,26 @@ std::string common_arg::to_string() {
static void common_params_handle_model_default(
std::string & model,
std::string & model_url,
const std::string & model_url,
std::string & hf_repo,
std::string & hf_file) {
std::string & hf_file,
const std::string & hf_token) {
if (!hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (hf_file.empty()) {
if (model.empty()) {
throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
auto auto_detected = common_get_hf_file(hf_repo, hf_token);
if (auto_detected.first.empty() || auto_detected.second.empty()) {
exit(1); // built without CURL, error message already printed
}
hf_repo = auto_detected.first;
hf_file = auto_detected.second;
} else {
hf_file = model;
}
hf_file = model;
} else if (model.empty()) {
}
// make sure model path is present (for caching purposes)
if (model.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = hf_repo + "_" + hf_file;
// to make sure we don't have any slashes in the filename
@@ -290,8 +299,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
// TODO: refactor model params in a common struct
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file);
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
if (params.escape) {
string_process_escapes(params.prompt);
@@ -367,6 +376,30 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
return devices;
}
static void add_rpc_devices(std::string servers) {
auto rpc_servers = string_split<std::string>(servers, ',');
if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified");
}
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
throw std::invalid_argument("failed to find RPC backend");
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
throw std::invalid_argument("failed to find RPC device add function");
}
for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
ggml_backend_device_register(dev);
} else {
throw std::invalid_argument("failed to register RPC device");
}
}
}
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
const common_params params_org = ctx_arg.params; // the example can modify the default params
@@ -768,15 +801,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-cnv", "--conversation"},
string_format(
"run in conversation mode:\n"
"- does not print special tokens and suffix/prefix\n"
"- interactive mode is also enabled\n"
"(default: %s)",
params.conversation ? "true" : "false"
),
"run in conversation mode:\n"
"- does not print special tokens and suffix/prefix\n"
"- interactive mode is also enabled\n"
"(default: auto enabled if chat template is available)",
[](common_params & params) {
params.conversation = true;
params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
{"-no-cnv", "--no-conversation"},
"force disable conversation mode (default: false)",
[](common_params & params) {
params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
@@ -1372,7 +1409,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--rpc"}, "SERVERS",
"comma separated list of RPC servers",
[](common_params & params, const std::string & value) {
params.rpc_servers = value;
add_rpc_devices(value);
GGML_UNUSED(params);
}
).set_env("LLAMA_ARG_RPC"));
}
@@ -1583,21 +1621,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_MODEL_URL"));
add_opt(common_arg(
{"-hfr", "--hf-repo"}, "REPO",
"Hugging Face model repository (default: unused)",
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
"example: unsloth/phi-4-GGUF:q4_k_m\n"
"(default: unused)",
[](common_params & params, const std::string & value) {
params.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO"));
add_opt(common_arg(
{"-hff", "--hf-file"}, "FILE",
"Hugging Face model file (default: unused)",
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
[](common_params & params, const std::string & value) {
params.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE"));
add_opt(common_arg(
{"-hfrv", "--hf-repo-v"}, "REPO",
{"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
"Hugging Face model repository for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.hf_repo = value;
+100 -7
View File
@@ -73,6 +73,22 @@
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
//
// CURL utils
//
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
struct curl_slist_ptr {
struct curl_slist * ptr = nullptr;
~curl_slist_ptr() {
if (ptr) {
curl_slist_free_all(ptr);
}
}
};
#endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json;
@@ -1027,7 +1043,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
@@ -1130,7 +1145,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
@@ -1144,11 +1160,9 @@ static bool common_download_file(const std::string & url, const std::string & pa
// Check if hf-token or bearer-token was specified
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer ";
auth_header += hf_token.c_str();
struct curl_slist *http_headers = NULL;
http_headers = curl_slist_append(http_headers, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
}
#if defined(_WIN32)
@@ -1444,6 +1458,80 @@ struct llama_model * common_load_model_from_hf(
return common_load_model_from_url(model_url, local_path, hf_token, params);
}
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
json model_info;
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
model_info = json::parse(res_str);
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
// check response
if (!model_info.contains("ggufFile")) {
throw std::runtime_error("error: model does not have ggufFile");
}
json & gguf_file = model_info.at("ggufFile");
if (!gguf_file.contains("rfilename")) {
throw std::runtime_error("error: ggufFile does not have rfilename");
}
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
}
#else
struct llama_model * common_load_model_from_url(
@@ -1465,6 +1553,11 @@ struct llama_model * common_load_model_from_hf(
return nullptr;
}
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return std::make_pair("", "");
}
#endif // LLAMA_USE_CURL
//
+16 -2
View File
@@ -103,6 +103,12 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};
enum common_conversation_mode {
COMMON_CONVERSATION_MODE_DISABLED = 0,
COMMON_CONVERSATION_MODE_ENABLED = 1,
COMMON_CONVERSATION_MODE_AUTO = 2,
};
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -240,7 +246,6 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
@@ -275,7 +280,6 @@ struct common_params {
bool special = false; // enable special token output
bool interactive = false; // interactive mode
bool interactive_first = false; // wait for user input immediately
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
@@ -301,6 +305,8 @@ struct common_params {
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector // NOLINT
std::vector<std::string> image; // path to image file(s)
@@ -454,6 +460,11 @@ static bool string_starts_with(const std::string & str,
return str.rfind(prefix, 0) == 0;
}
static bool string_ends_with(const std::string & str,
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -501,6 +512,9 @@ struct llama_model * common_load_model_from_hf(
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
+60
View File
@@ -2882,6 +2882,66 @@ class InternLM2Model(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("InternLM3ForCausalLM")
class InternLM3Model(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
def set_vocab(self):
tokens, scores, toktypes = self._create_vocab_sentencepiece()
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
if "added_tokens_decoder" in tokenizer_config_json:
for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items():
if token_data.get("special"):
token_id = int(token_id)
token = token_data["content"]
special_vocab._set_special_token(token, token_id)
# update eos token
if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids:
special_vocab.special_token_ids["eos"] = token_id
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "head_dim" in hparams:
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]
@Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
class BertModel(Model):
model_arch = gguf.MODEL_ARCH.BERT
+5 -5
View File
@@ -41,7 +41,7 @@ echo PASS
echo
# 2b. Test the sharded model is loading properly
$MAIN --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32
$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32
echo PASS
echo
@@ -51,7 +51,7 @@ echo PASS
echo
# 3b. Test the merged model is loading properly
$MAIN --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32
$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32
echo PASS
echo
@@ -61,7 +61,7 @@ echo PASS
echo
# 4b. Test the sharded model is loading properly
$MAIN --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32
$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32
echo PASS
echo
@@ -71,7 +71,7 @@ echo
#echo
# 5b. Test the merged model is loading properly
#$MAIN --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32
#$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32
#echo PASS
#echo
@@ -81,7 +81,7 @@ echo PASS
echo
# 6b. Test the sharded model is loading properly
$MAIN --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32
$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32
echo PASS
echo
+33 -4
View File
@@ -683,7 +683,7 @@ struct cmd_params_instance {
bool cpu_strict;
int poll;
int n_gpu_layers;
std::string rpc_servers;
std::string rpc_servers_str;
llama_split_mode split_mode;
int main_gpu;
bool no_kv_offload;
@@ -696,8 +696,37 @@ struct cmd_params_instance {
llama_model_params mparams = llama_model_default_params();
mparams.n_gpu_layers = n_gpu_layers;
if (!rpc_servers.empty()) {
mparams.rpc_servers = rpc_servers.c_str();
if (!rpc_servers_str.empty()) {
auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
// add RPC devices
if (!rpc_servers.empty()) {
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
exit(1);
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
exit(1);
}
static std::vector<ggml_backend_dev_t> devices;
devices.clear();
for (const std::string & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
devices.push_back(dev);
} else {
fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
exit(1);
}
}
devices.push_back(nullptr);
mparams.devices = devices.data();
}
}
mparams.split_mode = split_mode;
mparams.main_gpu = main_gpu;
@@ -708,7 +737,7 @@ struct cmd_params_instance {
}
bool equal_mparams(const cmd_params_instance & other) const {
return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers == other.rpc_servers &&
return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
tensor_split == other.tensor_split;
}
@@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
jlong context_pointer,
jlong batch_pointer,
jstring jtext,
jboolean format_chat,
jint n_len
) {
@@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto tokens_list = common_tokenize(context, text, 1);
bool parse_special = (format_chat == JNI_TRUE);
const auto tokens_list = common_tokenize(context, text, true, parse_special);
auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
}
for (auto id : tokens_list) {
LOGi("%s", common_token_to_piece(context, id).c_str());
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
}
common_batch_clear(*batch);
@@ -65,6 +65,7 @@ class LLamaAndroid {
context: Long,
batch: Long,
text: String,
formatChat: Boolean,
nLen: Int
): Int
@@ -115,10 +116,10 @@ class LLamaAndroid {
}
}
fun send(message: String): Flow<String> = flow {
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) {
+34 -10
View File
@@ -30,6 +30,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
static llama_context ** g_ctx;
static llama_model ** g_model;
static common_sampler ** g_smpl;
@@ -204,8 +206,24 @@ int main(int argc, char ** argv) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
}
// auto enable conversation mode if chat template is available
const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
} else {
params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
}
}
// in case user force-activate conversation mode (via -cnv) without proper chat template, we show a warning
if (params.conversation_mode && !has_chat_template) {
LOG_WRN("%s: chat template is not available or is not supported. This may cause the model to output suboptimal responses\n", __func__);
}
// print chat template example in conversation mode
if (params.conversation) {
if (params.conversation_mode) {
if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
} else {
@@ -252,8 +270,10 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
{
auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
auto prompt = (params.conversation_mode && params.enable_chat_template)
// format the system prompt in conversation mode (fallback to default if empty)
? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
// otherwise use the prompt as is
: params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
LOG_DBG("tokenize the prompt\n");
@@ -327,7 +347,7 @@ int main(int argc, char ** argv) {
params.n_keep += add_bos; // always keep the BOS token
}
if (params.conversation) {
if (params.conversation_mode) {
params.interactive_first = true;
}
@@ -451,7 +471,11 @@ int main(int argc, char ** argv) {
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
LOG_INF( " - Press Ctrl+C to interject at any time.\n");
#endif
LOG_INF( "%s\n", control_message);
LOG_INF( "%s", control_message);
if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) {
LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n");
}
LOG_INF("\n");
is_interacting = params.interactive_first;
}
@@ -763,7 +787,7 @@ int main(int argc, char ** argv) {
}
// if current token is not EOG, we add it to current assistant message
if (params.conversation) {
if (params.conversation_mode) {
const auto id = common_sampler_last(smpl);
assistant_ss << common_token_to_piece(ctx, id, false);
}
@@ -771,7 +795,7 @@ int main(int argc, char ** argv) {
if (n_past > 0 && is_interacting) {
LOG_DBG("waiting for user input\n");
if (params.conversation) {
if (params.conversation_mode) {
LOG("\n> ");
}
@@ -781,7 +805,7 @@ int main(int argc, char ** argv) {
}
std::string buffer;
if (!params.input_prefix.empty() && !params.conversation) {
if (!params.input_prefix.empty() && !params.conversation_mode) {
LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
LOG("%s", params.input_prefix.c_str());
}
@@ -805,7 +829,7 @@ int main(int argc, char ** argv) {
// Entering a empty line lets the user pass control back
if (buffer.length() > 1) {
// append input suffix if any
if (!params.input_suffix.empty() && !params.conversation) {
if (!params.input_suffix.empty() && !params.conversation_mode) {
LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
LOG("%s", params.input_suffix.c_str());
}
@@ -818,7 +842,7 @@ int main(int argc, char ** argv) {
string_process_escapes(buffer);
}
bool format_chat = params.conversation && params.enable_chat_template;
bool format_chat = params.conversation_mode && params.enable_chat_template;
std::string user_inp = format_chat
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
: std::move(buffer);
+2 -2
View File
@@ -47,7 +47,7 @@ echo PASS
echo
# 3a. Test the requanted model is loading properly
$MAIN --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32
$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32
echo PASS
echo
@@ -57,7 +57,7 @@ echo PASS
echo
# 4b. Test the requanted model is loading properly
$MAIN --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32
$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32
echo PASS
echo
Binary file not shown.
+6 -2
View File
@@ -37,7 +37,7 @@
<div v-for="conv in conversations" :class="{
'btn btn-ghost justify-start font-normal': true,
'btn-active': conv.id === viewingConvId,
}" @click="setViewingConv(conv.id)">
}" @click="setViewingConv(conv.id)" dir="auto">
<span class="truncate">{{ conv.messages[0].content }}</span>
</div>
<div class="text-center text-xs opacity-40 mt-auto mx-4">
@@ -156,6 +156,7 @@
@keydown.enter.shift.exact.prevent="inputMsg += '\n'"
:disabled="isGenerating"
id="msg-input"
dir="auto"
></textarea>
<button v-if="!isGenerating" class="btn btn-primary ml-2" @click="sendMessage" :disabled="inputMsg.length === 0">Send</button>
<button v-else class="btn btn-neutral ml-2" @click="stopGeneration">Stop</button>
@@ -248,6 +249,7 @@
<!-- textarea for editing message -->
<template v-if="editingContent !== null">
<textarea
dir="auto"
class="textarea textarea-bordered bg-base-100 text-base-content w-[calc(90vw-8em)] lg:w-96"
v-model="editingContent"></textarea>
<br/>
@@ -258,7 +260,9 @@
<!-- show loading dots for pending message -->
<span v-if="msg.content === null" class="loading loading-dots loading-md"></span>
<!-- render message as markdown -->
<vue-markdown v-else :source="msg.content"></vue-markdown>
<div v-else dir="auto">
<vue-markdown :source="msg.content"></vue-markdown>
</div>
<!-- render timings if enabled -->
<div class="dropdown dropdown-hover dropdown-top mt-2" v-if="timings && config.showTokensPerSecond">
<div tabindex="0" role="button" class="cursor-pointer font-semibold text-sm opacity-60">Speed: {{ timings.predicted_per_second.toFixed(1) }} t/s</div>
+2 -2
View File
@@ -111,12 +111,12 @@ const VueMarkdown = defineComponent(
highlight: function (str, lang) { // Add highlight.js
if (lang && hljs.getLanguage(lang)) {
try {
return '<pre><code class="hljs">' +
return '<pre dir="auto"><code class="hljs">' +
hljs.highlight(str, { language: lang, ignoreIllegals: true }).value +
'</code></pre>';
} catch (__) {}
}
return '<pre><code class="hljs">' + md.value.utils.escapeHtml(str) + '</code></pre>';
return '<pre dir="auto"><code class="hljs">' + md.value.utils.escapeHtml(str) + '</code></pre>';
}
}));
// support latex with double dollar sign and square brackets
+37
View File
@@ -78,3 +78,40 @@ play the audio:
$ aplay output.wav
```
### Running the example with llama-server
Running this example with `llama-server` is also possible and requires two
server instances to be started. One will serve the LLM model and the other
will serve the voice decoder model.
The LLM model server can be started with the following command:
```console
$ ./build/bin/llama-server -m ./models/outetts-0.2-0.5B-q8_0.gguf --port 8020
```
And the voice decoder model server can be started using:
```console
./build/bin/llama-server -m ./models/wavtokenizer-large-75-f16.gguf --port 8021 --embeddings --pooling none
```
Then we can run [tts-outetts.py](tts-outetts.py) to generate the audio.
First create a virtual environment for python and install the required
dependencies (this in only required to be done once):
```console
$ python3 -m venv venv
$ source venv/bin/activate
(venv) pip install requests numpy
```
And then run the python script using:
```conole
(venv) python ./examples/tts/tts-outetts.py http://localhost:8020 http://localhost:8021 "Hello world"
spectrogram generated: n_codes: 90, n_embd: 1282
converting to audio ...
audio generated: 28800 samples
audio written to file "output.wav"
```
And to play the audio we can again use aplay or any other media player:
```console
$ aplay output.wav
```
+126 -2
View File
@@ -3,6 +3,121 @@ import sys
#import struct
import requests
import re
import struct
import numpy as np
from concurrent.futures import ThreadPoolExecutor
def fill_hann_window(size, periodic=True):
if periodic:
return np.hanning(size + 1)[:-1]
return np.hanning(size)
def irfft(n_fft, complex_input):
return np.fft.irfft(complex_input, n=n_fft)
def fold(buffer, n_out, n_win, n_hop, n_pad):
result = np.zeros(n_out)
n_frames = len(buffer) // n_win
for i in range(n_frames):
start = i * n_hop
end = start + n_win
result[start:end] += buffer[i * n_win:(i + 1) * n_win]
return result[n_pad:-n_pad] if n_pad > 0 else result
def process_frame(args):
l, n_fft, ST, hann = args
frame = irfft(n_fft, ST[l])
frame = frame * hann
hann2 = hann * hann
return frame, hann2
def embd_to_audio(embd, n_codes, n_embd, n_thread=4):
embd = np.asarray(embd, dtype=np.float32).reshape(n_codes, n_embd)
n_fft = 1280
n_hop = 320
n_win = 1280
n_pad = (n_win - n_hop) // 2
n_out = (n_codes - 1) * n_hop + n_win
hann = fill_hann_window(n_fft, True)
E = np.zeros((n_embd, n_codes), dtype=np.float32)
for l in range(n_codes):
for k in range(n_embd):
E[k, l] = embd[l, k]
half_embd = n_embd // 2
S = np.zeros((n_codes, half_embd + 1), dtype=np.complex64)
for k in range(half_embd):
for l in range(n_codes):
mag = E[k, l]
phi = E[k + half_embd, l]
mag = np.clip(np.exp(mag), 0, 1e2)
S[l, k] = mag * np.exp(1j * phi)
res = np.zeros(n_codes * n_fft)
hann2_buffer = np.zeros(n_codes * n_fft)
with ThreadPoolExecutor(max_workers=n_thread) as executor:
args = [(l, n_fft, S, hann) for l in range(n_codes)]
results = list(executor.map(process_frame, args))
for l, (frame, hann2) in enumerate(results):
res[l*n_fft:(l+1)*n_fft] = frame
hann2_buffer[l*n_fft:(l+1)*n_fft] = hann2
audio = fold(res, n_out, n_win, n_hop, n_pad)
env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad)
mask = env > 1e-10
audio[mask] /= env[mask]
return audio
def save_wav(filename, audio_data, sample_rate):
num_channels = 1
bits_per_sample = 16
bytes_per_sample = bits_per_sample // 8
data_size = len(audio_data) * bytes_per_sample
byte_rate = sample_rate * num_channels * bytes_per_sample
block_align = num_channels * bytes_per_sample
chunk_size = 36 + data_size # 36 = size of header minus first 8 bytes
header = struct.pack(
'<4sI4s4sIHHIIHH4sI',
b'RIFF',
chunk_size,
b'WAVE',
b'fmt ',
16, # fmt chunk size
1, # audio format (PCM)
num_channels,
sample_rate,
byte_rate,
block_align,
bits_per_sample,
b'data',
data_size
)
audio_data = np.clip(audio_data * 32767, -32768, 32767)
pcm_data = audio_data.astype(np.int16)
with open(filename, 'wb') as f:
f.write(header)
f.write(pcm_data.tobytes())
def process_text(text: str):
text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed
@@ -170,6 +285,15 @@ n_embd = len(embd[0])
print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd))
# post-process the spectrogram to convert to audio
# TODO: see the tts.cpp:embd_to_audio() and implement it in Python
print('converting to audio ...')
print('TODO: see the tts.cpp:embd_to_audio() and implement it in Python')
audio = embd_to_audio(embd, n_codes, n_embd)
print('audio generated: %d samples' % len(audio))
filename = "output.wav"
sample_rate = 24000 # sampling rate
# zero out first 0.25 seconds
audio[:24000 // 4] = 0.0
save_wav(filename, audio, sample_rate)
print('audio written to file "%s"' % filename)
+3
View File
@@ -185,6 +185,9 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
# extra artifacts
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
+2
View File
@@ -203,6 +203,8 @@ extern "C" {
// Backend registry
//
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Backend (reg) enumeration
GGML_API size_t ggml_backend_reg_count(void);
GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
+26 -5
View File
@@ -1384,16 +1384,20 @@ extern "C" {
float scale,
float max_bias);
GGML_API struct ggml_tensor * ggml_soft_max_back(
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b,
float scale,
float max_bias);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b,
float scale,
float max_bias);
// rotary position embedding
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)
@@ -1500,7 +1504,7 @@ extern "C" {
// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
GGML_API struct ggml_tensor * ggml_rope_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a, // gradients of ggml_rope result
struct ggml_tensor * b, // positions
@@ -1515,6 +1519,23 @@ extern "C" {
float beta_fast,
float beta_slow);
GGML_API struct ggml_tensor * ggml_rope_multi_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// clamp
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_clamp(
+5
View File
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
return true;
}
// ops that return true for this function must not use restrict pointers for their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_LOG:
case GGML_OP_UNARY:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
return true;
default:
-1
View File
@@ -208,7 +208,6 @@ extern "C" {
// Internal backend registry API
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Add backend dynamic loading support to the backend
+82 -1
View File
@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
uint32_t utmp[4];
#ifdef __ARM_NEON
#ifdef __ARM_FEATURE_SVE
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
uint32x2_t mins8 = { 0 };
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1;
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
sumf -= dmin * vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
svint32_t sumi1_1 = svdup_n_s32(0);
svint32_t sumi1_2 = svdup_n_s32(0);
svint32_t sumi2 = svdup_n_s32(0);
svint32_t sumi2_1 = svdup_n_s32(0);
svint32_t sumi2_2 = svdup_n_s32(0);
switch (vector_length) {
case 128:
{
for (int j = 0; j < QK_K/64; ++j) {
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
q4 += 32;
}
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
} break;
case 256:
case 512:
{
for (int j = 0; j < QK_K/64; ++j) {
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
}
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
} break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sumf;
#elif __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);
+131 -58
View File
@@ -3967,6 +3967,57 @@ static void ggml_compute_forward_dup_bytes(
}
}
static void ggml_compute_forward_dup_q(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
size_t qk = ggml_blck_size(type);
const int64_t nr = ggml_nelements(src1) / qk;
// destination must be contiguous in the first dimension
GGML_ASSERT(nb10 == ggml_type_size(dst->type));
// must either have first dimension large enough to hold a row, or fully contiguous
GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
const int ith = params->ith;
const int nth = params->nth;
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int64_t ir = ir0; ir < ir1; ++ir) {
uint32_t i = ir * qk;
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
dequantize_row_q(
(const void *) ((char *) src0->data + x_offset),
(float *) ((char *) dst->data + dst_offset), qk);
}
}
static void ggml_compute_forward_dup(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -3993,6 +4044,10 @@ static void ggml_compute_forward_dup(
} break;
default:
{
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
ggml_compute_forward_dup_q(params, dst);
break;
}
GGML_ABORT("fatal error");
}
}
@@ -6691,20 +6746,20 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * grad = dst->src[1];
const struct ggml_tensor * grad = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
assert(ggml_is_contiguous_1(grad));
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(src1));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
assert(ggml_are_same_shape(src0, grad));
assert(ggml_are_same_shape(src1, dst));
assert(ggml_are_same_shape(src1, grad));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
const int nc = src1->ne[0];
const int nr = ggml_nrows(src1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6771,7 @@ static void ggml_compute_forward_silu_back_f32(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_silu_backward_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])),
(float *) ((char *) src1->data + i1*(src1->nb[1])),
(float *) ((char *) grad->data + i1*(grad->nb[1])));
#ifndef NDEBUG
@@ -6895,7 +6950,7 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +7021,7 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7073,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -7042,8 +7098,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
const int64_t i12 = i02;
const int64_t i13 = i03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
ggml_float sum_xx = 0.0;
ggml_float sum_xdz = 0.0;
@@ -7066,9 +7122,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
{
// z = rms_norm(x)
//
// rms_norm(src0) =
// rms_norm(src1) =
// scale(
// src0,
// src1,
// div(
// 1,
// sqrt(
@@ -7076,13 +7132,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
// scale(
// sum(
// sqr(
// src0)),
// src1)),
// (1.0/N)),
// eps))));
// postorder:
// ## op args grad
// 00 param src0 grad[#00]
// 00 param src1 grad[#00]
// 01 const 1
// 02 sqr (#00) grad[#02]
// 03 sum (#02) grad[#03]
@@ -7159,6 +7215,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7750,12 +7807,13 @@ static void ggml_compute_forward_out_prod_f32(
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne2 % ne02 == 0);
GGML_ASSERT(ne3 % ne03 == 0);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7855,10 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
const int64_t blck_1 = 16;
// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7869,8 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
@@ -8906,9 +8968,9 @@ static void ggml_compute_forward_soft_max(
}
// ggml_compute_forward_soft_max_back
// ggml_compute_forward_soft_max_ext_back
static void ggml_compute_forward_soft_max_back_f32(
static void ggml_compute_forward_soft_max_ext_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -8921,6 +8983,14 @@ static void ggml_compute_forward_soft_max_back_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src1, dst));
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
GGML_ASSERT(max_bias == 0.0f);
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
@@ -8969,10 +9039,11 @@ static void ggml_compute_forward_soft_max_back_f32(
// linear runtime, no additional memory
float dot_y_dy = 0;
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_scale_f32(nc, dx, scale);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
@@ -8983,7 +9054,7 @@ static void ggml_compute_forward_soft_max_back_f32(
}
}
static void ggml_compute_forward_soft_max_back(
static void ggml_compute_forward_soft_max_ext_back(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -8992,7 +9063,7 @@ static void ggml_compute_forward_soft_max_back(
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_soft_max_back_f32(params, dst);
ggml_compute_forward_soft_max_ext_back_f32(params, dst);
} break;
default:
{
@@ -9985,9 +10056,10 @@ static void ggml_compute_forward_im2col_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -10009,11 +10081,11 @@ static void ggml_compute_forward_im2col_back_f32(
const int64_t IH = is_2D ? ne1 : 1;
const int64_t IW = ne0;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
const int64_t KH = is_2D ? ne11 : 1;
const int64_t KW = ne10;
const int64_t OH = is_2D ? ne12 : 1;
const int64_t OW = ne11;
const int64_t OH = is_2D ? ne02 : 1;
const int64_t OW = ne01;
int ofs0 = is_2D ? nb3 : nb2;
int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10131,9 @@ static void ggml_compute_forward_im2col_back_f32(
continue;
}
const float * const src_data = (const float *) src1->data
const float * const grad_in = (const float *) src0->data
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
}
}
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -12484,22 +12556,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * opt0 = dst->src[2];
const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(src0f));
GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(grad));
GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
const int64_t ith = params->ith;
const int64_t nth = params->nth;
// TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0];
const int64_t nr = ggml_nrows(src0);
const int64_t nc = src0f->ne[0];
const int64_t nr = ggml_nrows(src0f);
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12580,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
for (int64_t i1 = ir0; i1 < ir1; i1++) {
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
#ifndef NDEBUG
for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12598,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
// soft_max
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0);
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
assert(sum > 0.0);
ggml_vec_scale_f32(nc, ds0, 1.0/sum);
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
// grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
ggml_vec_sub_f32(nc, ds0, ds0, s1);
ggml_vec_scale_f32(nc, ds0, d_by_nr);
@@ -12827,7 +12899,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX_BACK:
{
ggml_compute_forward_soft_max_back(params, tensor);
ggml_compute_forward_soft_max_ext_back(params, tensor);
} break;
case GGML_OP_ROPE:
{
@@ -13668,6 +13740,7 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break;
+10 -2
View File
@@ -403,8 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
case GGML_OP_MUL_MAT:
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
case GGML_OP_ROPE_BACK:
return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
case GGML_OP_SOFT_MAX_BACK: {
if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
return false;
}
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_IM2COL_BACK:
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
case GGML_OP_OUT_PROD:
+99 -76
View File
@@ -5,95 +5,89 @@
#include <cmath>
#include <cstdint>
static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
template <bool use_shared>
static __global__ void cross_entropy_loss_f32(
const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
extern __shared__ float tmp[];
const int ne_tmp = WARP_SIZE*nclasses;
extern __shared__ float tmp_all[];
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
// Each warp first loads ne_tmp logits/labels into shared memory:
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
const int ig = i0*nclasses + i; // ig == i global
tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
}
// Each thread in the warp then calculates the cross entropy loss for a single row.
// TODO: pad in order to avoid shared memory bank conflicts.
logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
// Find maximum for softmax:
float max = -INFINITY;
for (int i = 0; i < nclasses; ++i) {
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
float max_logit = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[i];
max_logit = fmaxf(max_logit, val);
if (use_shared) {
tmp[i] = val;
}
}
max_logit = warp_reduce_max(max_logit);
// Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f;
for (int i = 0; i < nclasses; ++i) {
float val = tmp_logits[lane_id*nclasses + i] - max;
sum += expf(val);
tmp_logits[lane_id*nclasses + i] = val;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float logit_i = use_shared ? tmp[i] : logits[i];
sum += expf(logit_i - max_logit);
}
sum = warp_reduce_sum(sum);
sum = logf(sum);
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f;
for (int i = 0; i < nclasses; ++i) {
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float logit_i = use_shared ? tmp[i] : logits[i];
loss += (logit_i - max_logit - sum) * labels[i];
}
loss = -warp_reduce_sum(loss) / (float)k;
__syncthreads();
if (lane_id == 0) {
tmp_all[warp_id] = loss;
}
__syncthreads();
if (warp_id != 0) {
return;
}
loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
loss = warp_reduce_sum(loss);
if (lane_id != 0) {
if (threadIdx.x != 0) {
return;
}
dst[blockIdx.x] = loss;
}
static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
template <bool use_shared>
static __global__ void cross_entropy_loss_back_f32(
const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
float * __restrict__ dst, const int nclasses) {
extern __shared__ float tmp[];
logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
dst += int64_t(blockIdx.x)*nclasses;
float maxval = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[blockIdx.x*nclasses + i];
const float val = logits[i];
maxval = fmaxf(maxval, val);
tmp[i] = val;
if (use_shared) {
tmp[i] = val;
}
}
maxval = warp_reduce_max(maxval);
float sum = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = expf(tmp[i] - maxval);
const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
sum += val;
tmp[i] = val;
if (use_shared) {
tmp[i] = val;
} else {
dst[i] = val;
}
}
sum = warp_reduce_sum(sum);
const float sm_scale = 1.0f/sum;
const float d_by_nrows = *loss/gridDim.x;
const float d_by_nrows = *grad/gridDim.x;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
const float val = use_shared ? tmp[i] : dst[i];
dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
}
}
@@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(nrows, 1, 1);
const size_t nbytes_shared = ne00*sizeof(float);
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
}
CUDA_CHECK(cudaGetLastError());
// Combine results from individual blocks:
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * opt0 = dst->src[2];
const ggml_tensor * grad = dst->src[0];
const ggml_tensor * src0f = dst->src[1];
const ggml_tensor * src1f = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(opt0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
GGML_ASSERT(src1f->type == GGML_TYPE_F32);
GGML_ASSERT( grad->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_is_scalar(grad));
GGML_ASSERT(ggml_is_contiguous(src0f));
GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
GGML_ASSERT(ggml_are_same_shape(src0f, dst));
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t ne00 = src0f->ne[0];
const int64_t nrows = ggml_nrows(src0f);
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
const float * opt0_d = (const float *) opt0->data;
float * dst_d = (float *) dst->data;
const float * grad_d = (const float *) grad->data;
const float * src0f_d = (const float *) src0f->data;
const float * src1f_d = (const float *) src1f->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(nrows, 1, 1);
const int shmem = ne00*sizeof(float);
const size_t nbytes_shared = ne00*sizeof(float);
cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
}
}
+107 -50
View File
@@ -3,15 +3,15 @@
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * src0, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
@@ -22,10 +22,10 @@ static __global__ void k_get_rows(
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
@@ -39,15 +39,15 @@ static __global__ void k_get_rows(
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
@@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = src0_row[i00];
}
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
float sum = 0.0f;
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
static void get_rows_cuda(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
@@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
GGML_ASSERT(ne00 % 2 == 0);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
GGML_UNUSED(dst);
}
template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
static void get_rows_cuda_float(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ne13 == 1);
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
@@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
//const size_t s13 = nb13 / ggml_element_size(src1);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
GGML_UNUSED(dst);
}
@@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
const void * src0_d = (const void *) src0->data;
const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
switch (src0->type) {
case GGML_TYPE_F16:
get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_F32:
get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
break;
default:
// TODO: k-quants
@@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
break;
}
}
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
GGML_TENSOR_BINARY_OP_LOCALS
const float * src0_d = (const float *) src0->data;
const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ne02*ne03 == 1);
GGML_ASSERT(ne12*ne13 == 1);
GGML_ASSERT(ne2*ne3 == 1);
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne1, 1);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
}
+3
View File
@@ -1,5 +1,8 @@
#include "common.cuh"
#define CUDA_GET_ROWS_BLOCK_SIZE 256
#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+260 -203
View File
@@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GET_ROWS:
ggml_cuda_op_get_rows(ctx, dst);
break;
case GGML_OP_GET_ROWS_BACK:
ggml_cuda_op_get_rows_back(ctx, dst);
break;
case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst);
break;
@@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_LEAKY_RELU:
ggml_cuda_op_leaky_relu(ctx, dst);
break;
case GGML_OP_SILU_BACK:
ggml_cuda_op_silu_back(ctx, dst);
break;
case GGML_OP_RMS_NORM:
ggml_cuda_op_rms_norm(ctx, dst);
break;
case GGML_OP_RMS_NORM_BACK:
ggml_cuda_op_rms_norm_back(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2138,9 +2147,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SOFT_MAX:
ggml_cuda_op_soft_max(ctx, dst);
break;
case GGML_OP_SOFT_MAX_BACK:
ggml_cuda_op_soft_max_back(ctx, dst);
break;
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
case GGML_OP_ROPE_BACK:
ggml_cuda_op_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
@@ -2289,6 +2304,66 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
#endif
}
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
if (node->op == GGML_OP_CPY) {
// store the copy op parameter which changes with each token.
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
} else {
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
}
}
}
if (!use_cuda_graph) {
break;
}
}
return use_cuda_graph;
}
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
graph_node_properties->node_address = node->data;
graph_node_properties->node_op = node->op;
@@ -2339,149 +2414,105 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return true;
}
#endif
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
ggml_cuda_set_device(cuda_ctx->device);
if (cuda_graph_update_required) {
// Extract nodes from graph
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.clear();
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.clear();
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
if (cuda_ctx->cuda_graph->num_nodes > 0) {
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
// vector of pointers to CUDA cpy kernels, which are required to identify
// kernel parameters which need updated in the graph for each token
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
}
}
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
if (cuda_ctx->cuda_graph->instance == nullptr) {
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
#endif
}
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
if (node->op == GGML_OP_CPY) {
// store the copy op parameter which changes with each token.
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
} else {
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
// Loop over nodes, and extract kernel parameters from each node
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
cudaGraphNodeType node_type;
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
if (node_type == cudaGraphNodeTypeKernel) {
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
if (stat == cudaErrorInvalidDeviceFunction) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on.
cudaGetLastError();
} else {
GGML_ASSERT(stat == cudaSuccess);
}
}
}
if (!use_cuda_graph) {
break;
}
} else {
// One of the arguments to the copy kernel is updated for each token, hence we need to
// replace that argument with the updated value in the CUDA graph
// on update steps, the live parameters will already be captured
int k = 0;
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
}
}
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
}
}
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH
bool graph_evaluated_or_captured = false;
if (cuda_ctx->cuda_graph->instance == nullptr) {
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
return cuda_graph_update_required;
}
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
cudaGraphExecUpdateResultInfo result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
cuda_ctx->cuda_graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
} else {
GGML_ASSERT(stat == cudaSuccess);
}
}
#endif
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
[[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
bool & cuda_graph_update_required) {
while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2519,19 +2550,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
cuda_ctx->cuda_graph->graph = nullptr;
}
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
#if 0
if (disable_cuda_graphs_due_to_failed_capture) {
use_cuda_graph = false;
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
#endif
} else {
graph_evaluated_or_captured = true; // CUDA graph has been captured
}
#endif
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
graph_evaluated_or_captured = true; // CUDA graph has been captured
} else {
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
@@ -2544,72 +2564,91 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
}
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
if (cuda_graph_update_required) {
// Extract nodes from graph
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.clear();
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.clear();
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
if (cuda_ctx->cuda_graph->num_nodes > 0) {
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
// Loop over nodes, and extract kernel parameters from each node
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
cudaGraphNodeType node_type;
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
if (node_type == cudaGraphNodeTypeKernel) {
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
if (stat == cudaErrorInvalidDeviceFunction) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on.
cudaGetLastError();
} else {
GGML_ASSERT(stat == cudaSuccess);
}
}
}
}
}
// One of the arguments to the copy kernel is updated for each token, hence we need to
// replace that argument with the updated value in the CUDA graph
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
int k = 0;
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
}
}
}
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
// Update graph executable
cudaGraphExecUpdateResultInfo result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
cuda_ctx->cuda_graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
} else {
GGML_ASSERT(stat == cudaSuccess);
}
update_cuda_graph_executable(cuda_ctx);
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
#else
graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH
#endif // USE_CUDA_GRAPH
}
}
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
// vector of pointers to CUDA cpy kernels, which are required to identify
// kernel parameters which need updated in the graph for each token
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
}
}
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
}
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS;
}
@@ -2885,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
} break;
case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
@@ -2901,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
} break;
case GGML_OP_GET_ROWS_BACK:
{
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
@@ -2974,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
return false;
} break;
case GGML_OP_SILU_BACK:
return ggml_is_contiguous(op->src[0]);
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
case GGML_OP_NONE:
@@ -3000,8 +3047,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_SOFT_MAX_BACK: {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ROPE_BACK: {
const size_t ts = ggml_type_size(op->src[0]->type);
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
}
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
@@ -3057,6 +3113,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
return op->ne[1];
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
return op->ne[2];
default:
return ggml_nrows(op);
+124 -31
View File
@@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float2 mean_var = make_float2(0.f, 0.f);
x += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float2 mean_var = make_float2(0.0f, 0.0f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
const float xi = x[col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
@@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
dst[col] = (x[col] - mean) * inv_std;
}
}
@@ -40,14 +44,8 @@ template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
int start = blockIdx.x * group_size;
int end = start + group_size;
start += threadIdx.x;
if (end >= ne_elements) {
end = ne_elements;
}
const int start = blockIdx.x*group_size + threadIdx.x;
const int end = min(blockIdx.x*group_size + group_size, ne_elements);
float tmp = 0.0f; // partial sum for thread in warp
@@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
@@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp = warp_reduce_sum(tmp);
}
float mean = tmp / group_size;
const float mean = tmp / group_size;
tmp = 0.0f;
for (int j = start; j < end; j += block_size) {
float xi = x[j] - mean;
const float xi = x[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
@@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
@@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp = warp_reduce_sum(tmp);
}
float variance = tmp / group_size;
float scale = rsqrtf(variance + eps);
const float variance = tmp / group_size;
const float scale = rsqrtf(variance + eps);
for (int j = start; j < end; j += block_size) {
dst[j] *= scale;
}
@@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
x += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
@@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * x[row*ncols + col];
dst[col] = scale * x[col];
}
}
template <int block_size>
static __global__ void rms_norm_back_f32(
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
grad += int64_t(row)*ncols;
xf += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
for (int col = tid; col < ncols; col += block_size) {
const float xfi = xf[col];
sum_xx += xfi * xfi;
sum_xg += xfi * grad[col];
}
// sum up partial sums
sum_xx = warp_reduce_sum(sum_xx);
sum_xg = warp_reduce_sum(sum_xg);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum_xx[32];
__shared__ float s_sum_xg[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum_xx[warp_id] = sum_xx;
s_sum_xg[warp_id] = sum_xg;
}
__syncthreads();
sum_xx = s_sum_xx[lane_id];
sum_xx = warp_reduce_sum(sum_xx);
sum_xg = s_sum_xg[lane_id];
sum_xg = warp_reduce_sum(sum_xg);
}
const float mean_eps = sum_xx / ncols + eps;
const float sum_eps = sum_xx + ncols*eps;
const float scale_grad = rsqrtf(mean_eps);
const float scale_x = -scale_grad * sum_xg/sum_eps;
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale_grad*grad[col] + scale_x*xf[col];
}
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
}
}
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
static void group_norm_f32_cuda(
const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
@@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}
static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
@@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
@@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
const float * grad_d = (const float *) grad->data;
const float * src0f_d = (const float *) src0f->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(grad));
GGML_ASSERT( grad->type == GGML_TYPE_F32);
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0f->ne[0];
const int64_t nrows = ggml_nrows(src0f);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}
+2
View File
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+26 -12
View File
@@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ne01 == ne11);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == src0->ne[2]);
GGML_ASSERT(ne2 % src0->ne[2] == 0);
GGML_ASSERT(ne3 % src0->ne[3] == 0);
GGML_ASSERT(ne2 == src1->ne[2]);
GGML_ASSERT(ne3 == src0->ne[3]);
GGML_ASSERT(ne3 == src1->ne[3]);
const float * src0_d = (const float *) src0->data;
@@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float alpha = 1.0f;
const float beta = 0.0f;
GGML_ASSERT(ne2 == 1);
GGML_ASSERT(ne3 == 1);
CUBLAS_CHECK(cublasSetStream(handle, stream));
const bool src1_T = ggml_is_transposed(src1);
@@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d, ne00,
src1_d, ldb,
&beta, dst_d, ne0));
// data strides in dimensions 2/3
const size_t s02 = nb02 / sizeof(float);
const size_t s03 = nb03 / sizeof(float);
const size_t s12 = nb12 / sizeof(float);
const size_t s13 = nb13 / sizeof(float);
const size_t s2 = nb2 / sizeof(float);
const size_t s3 = nb3 / sizeof(float);
// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
// TODO batched matrix multiplication
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
}
}
}
+162 -204
View File
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
template<bool forward>
static __device__ void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
float mscale, float & cos_theta, float & sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
cos_theta = cosf(theta) * mscale;
sin_theta = sinf(theta) * mscale;
if (!forward) {
sin_theta *= -1.0f;
}
}
template<typename T, bool has_ff>
template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
return;
}
const int i = row*ne0 + i0;
const int i2 = row/p_delta_rows;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + 1];
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + 1] = x0*sin_theta + x1*cos_theta;
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
return;
}
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
return;
}
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
int sec_w = sections.v[1] + sections.v[0];
int sector = (i0 / 2) % sect_dims;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows; // i2-th tokens
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
int sect_dims = sections.v[0] + sections.v[1];
int sec_w = sections.v[1] + sections.v[0];
int sector = (i0 / 2) % sect_dims;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[i2]*powf(theta_scale, p);
theta_base = pos[channel_x]*powf(theta_scale, p);
}
else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[i2 + ne2]*powf(theta_scale, p);
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims];
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
template<typename T>
template<bool forward, typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
} else {
rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
}
}
template<typename T>
template<bool forward, typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors
);
rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors);
}
}
template<typename T>
template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}
template<typename T>
template<bool forward, typename T>
static void rope_vision_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}
static void rope_norm_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_norm_cuda_f32(
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_multi_cuda_f16(
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_multi_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_vision_cuda_f16(
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_vision_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
rope_neox_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
rope_neox_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
rope_vision_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
rope_vision_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
rope_norm_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_norm_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, stream
);
rope_norm_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
} else {
GGML_ABORT("fatal error");
}
}
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<true>(ctx, dst);
}
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<false>(ctx, dst);
}
+2
View File
@@ -3,3 +3,5 @@
#define CUDA_ROPE_BLOCK_SIZE 256
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+101 -33
View File
@@ -1,5 +1,7 @@
#include "common.cuh"
#include "ggml.h"
#include "softmax.cuh"
#include <cstdint>
template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
@@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val);
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
x += int64_t(rowx)*ncols;
mask += int64_t(rowy)*ncols * (mask != nullptr);
dst += int64_t(rowx)*ncols;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
@@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = -INFINITY;
@@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
break;
}
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
@@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
return;
}
const int64_t idst = (int64_t)rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
dst[col] = vals[col] * inv_sum;
}
}
static __global__ void soft_max_back_f32(
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
grad += int64_t(rowx)*ncols;
dstf += int64_t(rowx)*ncols;
dst += int64_t(rowx)*ncols;
float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
for (int col = tid; col < ncols; col += WARP_SIZE) {
dgf_dot += dstf[col]*grad[col];
}
dgf_dot = warp_reduce_sum(dgf_dot);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
}
}
@@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const uint32_t n_head = nrows_x/nrows_y;
@@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}
static void soft_max_back_f32_cuda(
const float * grad, const float * dstf, float * dst,
const int ncols, const int nrows, const float scale, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
float * dst_d = (float *) dst->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
}
}
void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // grad
const ggml_tensor * src1 = dst->src[1]; // forward pass output
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
GGML_ASSERT(max_bias == 0.0f);
soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
}
+2
View File
@@ -3,3 +3,5 @@
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+36
View File
@@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void silu_back_f32(
const float * grad, const float * xf, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const float xfi = xf[i];
const float s = 1.0f / (1.0f + expf(-xfi));
dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
}
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // input from forward pass
const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
+3
View File
@@ -4,6 +4,7 @@
#define CUDA_STEP_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_SILU_BACK_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
@@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+1
View File
@@ -29,5 +29,6 @@
#include "wkv6.hpp"
#include "outprod.hpp"
#include "element_wise.hpp"
#include "gla.hpp"
#endif // GGML_SYCL_BACKEND_HPP
+4
View File
@@ -4040,6 +4040,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
default:
return false;
}
@@ -4507,6 +4510,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
default:
return false;
+105
View File
@@ -0,0 +1,105 @@
#include <sycl/sycl.hpp>
#include "common.hpp"
template <u_int HEAD_SIZE>
static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale,
const float * k, const float * v, const float * r, const float * td,
const float * s, float * dst) {
const u_int head_size = HEAD_SIZE;
const u_int state_size = C * head_size;
const u_int n_seq_tokens = T / B;
sycl::range<1> block_dims((C / H));
sycl::range<1> grid_dims((B * H));
stream->submit([&](sycl::handler & cgh) {
/* local memory accessors*/
auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
u_int tid = item.get_local_id(0);
u_int bid = item.get_group(0);
u_int batch_i = bid / H;
u_int head_i = bid % H;
float state[head_size];
#pragma unroll
for (u_int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
item.barrier(sycl::access::fence_space::local_space); //sync threads
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
item.barrier(sycl::access::fence_space::local_space); //sync threads
const float _v = v[t];
float y = 0;
for (u_int j = 0; j < head_size; j += 4) {
const sycl::float4 & k = (sycl::float4 &) (_k[j]);
const sycl::float4 & r = (sycl::float4 &) (_r[j]);
const sycl::float4 & td = (sycl::float4 &) (_td[j]);
sycl::float4 & s = (sycl::float4 &) (state[j]);
sycl::float4 kv;
kv.x() = k.x() * _v;
kv.y() = k.y() * _v;
kv.z() = k.z() * _v;
kv.w() = k.w() * _v;
s.x() = s.x() * td.x() + kv.x();
s.y() = s.y() * td.y() + kv.y();
s.z() = s.z() * td.z() + kv.z();
s.w() = s.w() * td.w() + kv.w();
y += r.x() * s.x();
y += r.y() * s.y();
y += r.z() * s.z();
y += r.w() * s.w();
}
dst[t] = y * scale;
}
#pragma unroll
for (u_int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
});
});
}
void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const float * k_d = static_cast<const float *>(dst->src[0]->data);
const float * v_d = static_cast<const float *>(dst->src[1]->data);
const float * r_d = static_cast<const float *>(dst->src[2]->data);
const float * td_d = static_cast<const float *>(dst->src[3]->data);
const float * s_d = static_cast<const float *>(dst->src[4]->data);
const int64_t B = dst->src[4]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == 64 || C / H == 128);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float * dst_d = (float *) dst->data;
if (C / H == 64) {
gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
} else {
gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
}
}
+8
View File
@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_GLA_HPP
#define GGML_SYCL_GLA_HPP
#include "common.hpp"
void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_GLA_HPP
+58 -6
View File
@@ -1,5 +1,20 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED)
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
else()
find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
endif()
set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE)
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
endfunction()
if (Vulkan_FOUND)
message(STATUS "Vulkan found")
@@ -73,19 +88,56 @@ if (Vulkan_FOUND)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif()
add_subdirectory(vulkan-shaders)
if (NOT CMAKE_CROSSCOMPILING)
add_subdirectory(vulkan-shaders)
if (MSVC)
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
string(TOUPPER ${CONFIG} CONFIG)
set_target_properties(vulkan-shaders-gen PROPERTIES
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()
else()
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
else()
detect_host_compiler()
if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
message(FATAL_ERROR "Host compiler not found")
else()
message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
endif()
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
include(ExternalProject)
# Native build through ExternalProject_Add
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
)
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
endif()
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
if (NOT CMAKE_CROSSCOMPILING)
set(_ggml_vk_genshaders_cmd "$<TARGET_FILE_DIR:vulkan-shaders-gen>/${_ggml_vk_genshaders_cmd}")
endif ()
if (CMAKE_CROSSCOMPILING)
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
endif()
add_custom_command(
OUTPUT ${_ggml_vk_header}
@@ -99,7 +151,7 @@ if (Vulkan_FOUND)
--target-cpp ${_ggml_vk_source}
--no-clean
DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd}
DEPENDS ${_ggml_vk_shader_deps}
COMMENT "Generate vulkan shaders"
)
@@ -0,0 +1,15 @@
set(CMAKE_BUILD_TYPE Release)
set(CMAKE_C_FLAGS -O2)
set(CMAKE_CXX_FLAGS -O2)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER)
set(CMAKE_C_COMPILER @HOST_C_COMPILER@)
set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@)
if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC")
foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()
+109 -11
View File
@@ -228,6 +228,8 @@ struct vk_device_struct {
vk_pipeline pipeline_repeat_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
@@ -384,10 +386,13 @@ struct vk_flash_attn_push_constants {
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
@@ -1965,6 +1970,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -3689,6 +3708,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cpy_f16_f16;
}
}
if (src->type == GGML_TYPE_F32) {
switch (to) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return ctx->device->pipeline_cpy_f32_quant[to];
default:
break;
}
}
if (to == GGML_TYPE_F32) {
switch (src->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return ctx->device->pipeline_cpy_quant_f32[src->type];
default:
break;
}
}
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
GGML_ABORT("fatal error");
@@ -4766,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
assert(pipelines);
bool aligned = (KV % pipelines[1]->align) == 0;
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
bool aligned = (KV % pipelines[1]->align) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
vk_pipeline pipeline = pipelines[aligned];
assert(pipeline);
@@ -4802,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
Q_uma = d_Q != nullptr;
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
D_uma = d_D != nullptr;
if (mask) {
ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
M_uma = d_M != nullptr;
}
}
@@ -4848,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
nbm1,
scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -5160,7 +5224,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
GGML_ASSERT(dst->buffer != nullptr);
const uint64_t ne00 = src0->ne[0];
@@ -7905,12 +7969,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
{
ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
if (src0_type == GGML_TYPE_F32) {
switch (src1_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
break;
}
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
return true;
if (src1_type == GGML_TYPE_F32) {
switch (src0_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return true;
default:
break;
}
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
@@ -8601,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
ggml_tensor * src2 = tensor->src[2];
ggml_tensor * src3 = tensor->src[3];
void * tensor_data = tensor->data;
@@ -8663,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
}
if (src3 != nullptr) {
std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
}
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
@@ -8707,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
}
if (src3 != nullptr) {
std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
}
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
@@ -8729,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
if (src2 != nullptr) {
std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
}
if (src3 != nullptr) {
std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
}
std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
std::cerr << std::endl << "Result:" << std::endl;
ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
@@ -1,9 +1,11 @@
find_package (Threads REQUIRED)
find_package(Vulkan COMPONENTS glslc REQUIRED)
find_program(GLSLC_EXECUTABLE glslc)
if(NOT GLSLC_EXECUTABLE)
message(FATAL_ERROR "glslc not found.")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)
@@ -0,0 +1,51 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = get_doffset() + dst_idx(idx);
uint src_idx = src0_idx_quant(idx, QUANT_K);
const uint a_offset = 0;
const uint ib = src_idx;
const vec2 dm = get_dm(ib, a_offset);
[[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
vec4 v = dequantize4(ib, j / QUANT_R, a_offset);
v = v * dm.x + vec4(dm.y);
#if QUANT_R == 2
data_d[dst_idx + j/2 + 0] = v[0];
data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];
data_d[dst_idx + j/2 + 1] = v[2];
data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];
#else
data_d[dst_idx + j + 0] = v[0];
data_d[dst_idx + j + 1] = v[1];
data_d[dst_idx + j + 2] = v[2];
data_d[dst_idx + j + 3] = v[3];
#endif
}
}
@@ -0,0 +1,237 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
const uint xi0 = min(15, int(x0 + 8.5));
const uint xi1 = min(15, int(x1 + 8.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q4_1)
void quantize(uint dst_idx, uint src_idx)
{
float vmin = 1.0/0.0;
float vmax = -vmin;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
const float v = data_s[src_idx + j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(vmin);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
const uint xi0 = min(15, int(x0 + 0.5));
const uint xi1 = min(15, int(x1 + 0.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q5_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
const uint xi0 = min(31, int(x0 + 16.5));
const uint xi1 = min(31, int(x1 + 16.5));
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
}
data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
}
#endif
#if defined(DATA_A_Q5_1)
void quantize(uint dst_idx, uint src_idx)
{
float min = data_s[src_idx + 0];
float max = min;
[[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
const float v = data_s[src_idx + j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = (d != 0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(min);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - min)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
const uint xi0 = uint(x0 + 0.5);
const uint xi1 = uint(x1 + 0.5);
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
}
data_q[dst_idx].qh = qh;
}
#endif
#if defined(DATA_A_Q8_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0; // absolute max
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
const float v = data_s[src_idx + j];
amax = max(amax, abs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
const float x0 = data_s[src_idx + j]*id;
data_q[dst_idx].qs[j] = int8_t(round(x0));
}
}
#endif
#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;
if (x >= kvalues_iq4nl[15]) return 15;
int ml = 0, mu = 15;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
}
return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
}
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = (d != 0.0) ? 1.0/d : 0.0;
float sumqx = 0, sumq2 = 0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
const uint xi0 = best_index(x0);
const uint xi1 = best_index(x1);
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
}
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = dst_idx_quant(idx, QUANT_K);
uint src_idx = get_aoffset() + src0_idx(idx);
quantize(dst_idx, src_idx);
}
@@ -101,19 +101,25 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_
block_q2_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
block_q2_K_packed16 block;
};
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;
const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
const uint scalesi = iqs / 16; // 0..15
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qs = (qs >> qsshift) & 0x0303;
qs = unpack8(qs)[idx & 1];
uint32_t qs = bl.block.qs[qsi];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
return ret;
}
@@ -157,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
const f16vec2 loadd = bl.block.d;
uvec4 v = bl128.block.q4k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc;
uint32_t mbyte;
uint32_t scidx0 = (is < 4) ? is : (is + 4);
uint32_t scidx1 = (is < 4) ? is : (is - 4);
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1];
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float16_t ret = d * float16_t(qs) - m;
@@ -204,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
block_q5_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
block_q5_K_packed128 block;
};
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
const uint32_t hm = 0x0101 << is;
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = bl.block.d;
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc;
uint32_t mbyte;
uint32_t scidx0 = (is < 4) ? is : (is + 4);
uint32_t scidx1 = (is < 4) ? is : (is - 4);
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = qh & hm;
qh = unpack8(qh)[idx & 1];
qh = ((qh >> is) & 0x101) << 4;
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs)[idx & 1];
qs = unpack8(qs | qh)[idx & 1];
float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m;
float16_t ret = d * (float16_t(qs)) - m;
return ret;
}
@@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
@@ -146,6 +149,23 @@ void main() {
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
// nb?1 are already divided by the type size and are in units of elements
uint32_t q_stride = p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
k_stride &= ~7;
v_stride &= ~7;
#endif
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
@@ -54,3 +54,23 @@ uint dst_idx(uint idx) {
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
}
uint src0_idx_quant(uint idx, uint qk) {
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00;
}
uint dst_idx_quant(uint idx, uint qk) {
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;
}
@@ -5,6 +5,80 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
barrier();
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();
if (i >= num_blocks_per_row)
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint step = 8;
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
uvec2 qs0 = uvec2(unpack8(qs0_u16));
uvec2 qs16 = uvec2(unpack8(qs16_u16));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
@@ -5,6 +5,74 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
if (!all_threads) { // when we don't have enough blocks to use all threads
barrier();
if (i < num_blocks_per_row)
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
if (i >= num_blocks_per_row)
continue;
}
const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
// 0, 1, 16, 17
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
if (all_threads) {
barrier();
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
}
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint itid8 = itid%8;
const uint step = 8;
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_im4 = v_im*4;
const uint v_in = itid - 8*v_im; // 0...7
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint8_t m = uint8_t(1 << (4 * v_im));
const uint32_t m = 0x01010101 << (4 * v_im);
uint32_t hm_m[4];
[[unroll]] for (uint j = 0; j < 4; ++j)
hm_m[j] = m << j;
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
const uint s_shift = 4 * v_im;
const uint s_shift = v_im4 + 2*(itid8/4);
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
u8vec2 s0 = unpack8(s0_16);
u8vec2 s2 = unpack8(s2_16);
u8vec2 s4 = unpack8(s4_16);
u8vec2 s6 = unpack8(s6_16);
u8vec2 s8 = unpack8(s8_16);
u8vec2 s10 = unpack8(s10_16);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}
}
}
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
@@ -6,6 +6,86 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
const FLOAT_TYPE sc4 = scale8_f.x;
const FLOAT_TYPE sc5 = scale8_f.y;
const FLOAT_TYPE sc6 = scale8_f.z;
const FLOAT_TYPE sc7 = scale8_f.w;
const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4));
const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4));
const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4));
const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4));
const FLOAT_TYPE q4_0 = qs0_lo4.x;
const FLOAT_TYPE q4_1 = qs0_lo4.y;
const FLOAT_TYPE q4_2 = qs0_lo4.z;
const FLOAT_TYPE q4_3 = qs0_lo4.w;
const FLOAT_TYPE q4_4 = qs0_hi4.x;
const FLOAT_TYPE q4_5 = qs0_hi4.y;
const FLOAT_TYPE q4_6 = qs0_hi4.z;
const FLOAT_TYPE q4_7 = qs0_hi4.w;
const FLOAT_TYPE q4_8 = qs64_lo4.x;
const FLOAT_TYPE q4_9 = qs64_lo4.y;
const FLOAT_TYPE q4_10 = qs64_lo4.z;
const FLOAT_TYPE q4_11 = qs64_lo4.w;
const FLOAT_TYPE q4_12 = qs64_hi4.x;
const FLOAT_TYPE q4_13 = qs64_hi4.y;
const FLOAT_TYPE q4_14 = qs64_hi4.z;
const FLOAT_TYPE q4_15 = qs64_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint step = 4;
const uint il = itid/step; // 0...3
const uint ir = itid - step*il; // 0...7 or 0...3
const uint il = itid/4; // 0...3
const uint ir = itid - 4*il; // 0...3
const uint n = 4;
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
@@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
const uint32_t sc0 = ( scale0.x & 0x3f);
const uint32_t sc1 = ( scale0.y & 0x3f);
const uint32_t sc2 = ( scale4.x & 0x3f);
const uint32_t sc3 = ( scale4.y & 0x3f);
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
const uint32_t q4_0 = qs0_lo4.x;
const uint32_t q4_1 = qs0_lo4.y;
const uint32_t q4_2 = qs0_lo4.z;
const uint32_t q4_3 = qs0_lo4.w;
const uint32_t q4_4 = qs0_hi4.x;
const uint32_t q4_5 = qs0_hi4.y;
const uint32_t q4_6 = qs0_hi4.z;
const uint32_t q4_7 = qs0_hi4.w;
const uint32_t q4_8 = qs64_lo4.x;
const uint32_t q4_9 = qs64_lo4.y;
const uint32_t q4_10 = qs64_lo4.z;
const uint32_t q4_11 = qs64_lo4.w;
const uint32_t q4_12 = qs64_hi4.x;
const uint32_t q4_13 = qs64_hi4.y;
const uint32_t q4_14 = qs64_hi4.z;
const uint32_t q4_15 = qs64_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
@@ -6,6 +6,118 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
const FLOAT_TYPE sc4 = scale8_f.x;
const FLOAT_TYPE sc5 = scale8_f.y;
const FLOAT_TYPE sc6 = scale8_f.z;
const FLOAT_TYPE sc7 = scale8_f.w;
const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint il = itid/4; // 0...3
const uint ir = itid - 4*il; // 0...7 or 0...3
const uint ir = itid - 4*il; // 0...3
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2;
@@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
const uint32_t sc0 = ( scale0.x & 0x3f);
const uint32_t sc1 = ( scale0.y & 0x3f);
const uint32_t sc2 = ( scale4.x & 0x3f);
const uint32_t sc3 = ( scale4.y & 0x3f);
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
const uint32_t q4_0 = qs0_16_lo4.x;
const uint32_t q4_1 = qs0_16_lo4.y;
const uint32_t q4_2 = qs0_16_lo4.z;
const uint32_t q4_3 = qs0_16_lo4.w;
const uint32_t q4_4 = qs0_16_hi4.x;
const uint32_t q4_5 = qs0_16_hi4.y;
const uint32_t q4_6 = qs0_16_hi4.z;
const uint32_t q4_7 = qs0_16_hi4.w;
const uint32_t q4_8 = qs64_80_lo4.x;
const uint32_t q4_9 = qs64_80_lo4.y;
const uint32_t q4_10 = qs64_80_lo4.z;
const uint32_t q4_11 = qs64_80_lo4.w;
const uint32_t q4_12 = qs64_80_hi4.x;
const uint32_t q4_13 = qs64_80_hi4.y;
const uint32_t q4_14 = qs64_80_hi4.z;
const uint32_t q4_15 = qs64_80_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size)
calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
@@ -6,7 +6,77 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
if (!all_threads) { // when we don't have enough blocks to use all threads
barrier();
if (i < num_blocks_per_row)
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
if (i >= num_blocks_per_row)
continue;
}
const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
if (all_threads) {
barrier();
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
}
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
[[unroll]] for (uint l = 0; l < 4; ++l) {
sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
}
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
}
}
}
void compute_outputs(const uint first_row, const uint num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint step = 8;
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
const uint is = v_in / 4;
@@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
FLOAT_TYPE scales[4];
scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
uvec4 q0 = uvec4(unpack8(q0_u32));
uvec4 q1 = uvec4(unpack8(q1_u32));
uvec4 q2 = uvec4(unpack8(q2_u32));
uvec4 q3 = uvec4(unpack8(q3_u32));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) {
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
}
temp[j][n] += sum * d;
}
}
}
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
@@ -227,6 +227,11 @@ struct block_q4_K_packed32
uint32_t qs[QUANT_K_Q4_K/2/4];
};
struct block_q4_K_packed128
{
uvec4 q4k[9];
};
#if defined(DATA_A_Q4_K)
#define QUANT_K QUANT_K_Q4_K
#define A_TYPE block_q4_K
@@ -252,6 +257,11 @@ struct block_q5_K_packed16
uint16_t qs[QUANT_K_Q5_K/2/2];
};
struct block_q5_K_packed128
{
uvec4 q5k[11];
};
#if defined(DATA_A_Q5_K)
#define QUANT_K QUANT_K_Q5_K
#define A_TYPE block_q5_K
@@ -30,8 +30,6 @@
#include <fcntl.h>
#endif
#include <vulkan/vulkan_core.h>
#define ASYNCIO_CONCURRENCY 64
std::mutex lock;
@@ -419,6 +417,11 @@ void process_shaders() {
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+64 -41
View File
@@ -3450,12 +3450,14 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
// ggml_soft_max_back
// ggml_soft_max_ext_back
static struct ggml_tensor * ggml_soft_max_back_impl(
static struct ggml_tensor * ggml_soft_max_ext_back_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float scale,
float max_bias,
bool inplace) {
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -3463,21 +3465,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
result->src[0] = a;
result->src[1] = b;
memcpy((float *) result->op_params + 0, &scale, sizeof(float));
memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
return result;
}
struct ggml_tensor * ggml_soft_max_back(
struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_soft_max_back_impl(ctx, a, b, false);
struct ggml_tensor * b,
float scale,
float max_bias) {
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
}
struct ggml_tensor * ggml_soft_max_back_inplace(
struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_soft_max_back_impl(ctx, a, b, true);
struct ggml_tensor * b,
float scale,
float max_bias) {
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
}
// ggml_rope
@@ -3695,7 +3704,7 @@ void ggml_rope_yarn_corr_dims(
// ggml_rope_back
struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * ggml_rope_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
@@ -3709,29 +3718,32 @@ struct ggml_tensor * ggml_rope_back(
float attn_factor,
float beta_fast,
float beta_slow) {
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE_BACK;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
struct ggml_tensor * result = ggml_rope_ext(
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
result->op = GGML_OP_ROPE_BACK;
return result;
}
struct ggml_tensor * ggml_rope_multi_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
struct ggml_tensor * result = ggml_rope_multi(
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
result->op = GGML_OP_ROPE_BACK;
return result;
}
// ggml_clamp
struct ggml_tensor * ggml_clamp(
@@ -5073,10 +5085,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c) {
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(ggml_is_scalar(c));
GGML_ASSERT(ggml_is_scalar(a));
GGML_ASSERT(ggml_are_same_shape(b, c));
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
result->src[0] = a;
@@ -5255,7 +5267,7 @@ static void ggml_sub_or_set(
}
static void ggml_compute_backward(
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
struct ggml_tensor * tensor = cgraph->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
@@ -5399,7 +5411,7 @@ static void ggml_compute_backward(
if (src0_needs_grads) {
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
}
} break;
case GGML_OP_MUL_MAT: {
@@ -5582,7 +5594,13 @@ static void ggml_compute_backward(
} break;
case GGML_OP_SOFT_MAX: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
}
GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
} break;
@@ -5594,6 +5612,7 @@ static void ggml_compute_backward(
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4] = {0, 0, 0, 0};
memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
@@ -5601,10 +5620,14 @@ static void ggml_compute_backward(
memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
memcpy(&sections, tensor->op_params + 11, sizeof(sections));
ggml_add_or_set(ctx, cgraph, isrc0,
ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
}
GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
} break;
@@ -5618,7 +5641,7 @@ static void ggml_compute_backward(
const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
}
} break;
case GGML_OP_POOL_2D: {
@@ -5661,7 +5684,7 @@ static void ggml_compute_backward(
} break;
case GGML_UNARY_OP_SILU: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
}
} break;
case GGML_UNARY_OP_EXP: {
@@ -5678,7 +5701,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_CROSS_ENTROPY_LOSS: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
}
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break;
+11 -4
View File
@@ -288,9 +288,6 @@ extern "C" {
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split;
// comma separated list of RPC servers to use for offloading
const char * rpc_servers;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted.
@@ -418,10 +415,20 @@ extern "C" {
struct llama_model_params params),
"use llama_model_load_from_file instead");
// Load the model from a file
// If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
// If the split file name does not follow this pattern, use llama_model_load_from_splits
LLAMA_API struct llama_model * llama_model_load_from_file(
const char * path_model,
struct llama_model_params params);
// Load the model from multiple splits (support custom naming scheme)
// The paths must be in the correct order
LLAMA_API struct llama_model * llama_model_load_from_splits(
const char ** paths,
size_t n_paths,
struct llama_model_params params);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead");
@@ -951,7 +958,7 @@ extern "C" {
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocabable_get_text instead");
DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead");
-112
View File
@@ -1,112 +0,0 @@
#!/bin/bash
#
# Shortcut for downloading HF models
#
# Usage:
# ./llama-cli -m $(./scripts/hf.sh https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf)
# ./llama-cli -m $(./scripts/hf.sh --url https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/blob/main/mixtral-8x7b-v0.1.Q4_K_M.gguf)
# ./llama-cli -m $(./scripts/hf.sh --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf)
#
# all logs go to stderr
function log {
echo "$@" 1>&2
}
function usage {
log "Usage: $0 [[--url] <url>] [--repo <repo>] [--file <file>] [--outdir <dir> [-h|--help]"
exit 1
}
# check for curl or wget
function has_cmd {
if ! [ -x "$(command -v $1)" ]; then
return 1
fi
}
if has_cmd wget; then
cmd="wget -q -c -O %s/%s %s"
elif has_cmd curl; then
cmd="curl -C - -f --output-dir %s -o %s -L %s"
else
log "[E] curl or wget not found"
exit 1
fi
url=""
repo=""
file=""
outdir="."
# parse args
while [[ $# -gt 0 ]]; do
case "$1" in
--url)
url="$2"
shift 2
;;
--repo)
repo="$2"
shift 2
;;
--file)
file="$2"
shift 2
;;
--outdir)
outdir="$2"
shift 2
;;
-h|--help)
usage
;;
*)
url="$1"
shift
;;
esac
done
if [ -n "$repo" ] && [ -n "$file" ]; then
url="https://huggingface.co/$repo/resolve/main/$file"
fi
if [ -z "$url" ]; then
log "[E] missing --url"
usage
fi
# check if the URL is a HuggingFace model, and if so, try to download it
is_url=false
if [[ ${#url} -gt 22 ]]; then
if [[ ${url:0:22} == "https://huggingface.co" ]]; then
is_url=true
fi
fi
if [ "$is_url" = false ]; then
log "[E] invalid URL, must start with https://huggingface.co"
exit 0
fi
# replace "blob/main" with "resolve/main"
url=${url/blob\/main/resolve\/main}
basename=$(basename $url)
log "[+] attempting to download $basename"
if [ -n "$cmd" ]; then
cmd=$(printf "$cmd" "$outdir" "$basename" "$url")
log "[+] $cmd"
if $cmd; then
echo $outdir/$basename
exit 0
fi
fi
log "[-] failed to download"
exit 1
+9 -1
View File
@@ -73,6 +73,7 @@ while read c; do
src/ggml*.h \
src/ggml*.c \
src/ggml*.cpp \
src/gguf*.cpp \
src/ggml-blas/* \
src/ggml-cann/* \
src/ggml-cpu/* \
@@ -81,10 +82,12 @@ while read c; do
src/ggml-kompute/* \
src/ggml-metal/* \
src/ggml-musa/* \
src/ggml-opencl/* \
src/ggml-rpc/* \
src/ggml-sycl/* \
src/ggml-vulkan/* \
include/ggml*.h \
include/gguf*.h \
tests/test-opt.cpp \
tests/test-quantize-fns.cpp \
tests/test-quantize-perf.cpp \
@@ -123,6 +126,7 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml*.c -> ggml/src/ggml*.c
# src/ggml*.cpp -> ggml/src/ggml*.cpp
# src/ggml*.h -> ggml/src/ggml*.h
# src/gguf*.cpp -> ggml/src/gguf*.cpp
# src/ggml-blas/* -> ggml/src/ggml-blas/*
# src/ggml-cann/* -> ggml/src/ggml-cann/*
# src/ggml-cpu/* -> ggml/src/ggml-cpu/*
@@ -131,11 +135,13 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml-kompute/* -> ggml/src/ggml-kompute/*
# src/ggml-metal/* -> ggml/src/ggml-metal/*
# src/ggml-musa/* -> ggml/src/ggml-musa/*
# src/ggml-opencl/* -> ggml/src/ggml-opencl/*
# src/ggml-rpc/* -> ggml/src/ggml-rpc/*
# src/ggml-sycl/* -> ggml/src/ggml-sycl/*
# src/ggml-vulkan/* -> ggml/src/ggml-vulkan/*
#
# include/ggml*.h -> ggml/include/ggml*.h
# include/gguf*.h -> ggml/include/gguf*.h
#
# tests/test*.cpp -> tests/
#
@@ -149,6 +155,7 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
-e 's/([[:space:]]|[ab]\/)src\/gguf(.*)\.cpp/\1ggml\/src\/gguf\2.cpp/g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \
@@ -156,11 +163,12 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/([[:space:]]|[ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-musa\//\1ggml\/src\/ggml-musa\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \
-e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \
-e 's/([[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
-e 's/([[:space:]]|[ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
-e 's/([[:space:]]|[ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \
-e 's/([[:space:]]|[ab]\/)LICENSE/\1LICENSE/g' \
-e 's/([[:space:]]|[ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \
+1 -1
View File
@@ -1 +1 @@
c8bd0fee71dc8328d93be301bbee06bc10d30429
d92321c0d151fe73a47d89738c7c3091ac904297
+3
View File
@@ -7,6 +7,7 @@ cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake
cp -rpv ../ggml/src/ggml*.c ./ggml/src/
cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/
cp -rpv ../ggml/src/ggml*.h ./ggml/src/
cp -rpv ../ggml/src/gguf*.cpp ./ggml/src/
cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/
cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/
@@ -15,11 +16,13 @@ cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/
cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/
cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/
cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/
cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/
cp -rpv ../ggml/src/ggml-rpc/* ./ggml/src/ggml-rpc/
cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/
cp -rpv ../ggml/src/ggml-vulkan/* ./ggml/src/ggml-vulkan/
cp -rpv ../ggml/include/ggml*.h ./ggml/include/
cp -rpv ../ggml/include/gguf*.h ./ggml/include/
cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp
cp -rpv ../ggml/tests/test-quantize-fns.cpp ./tests/test-quantize-fns.cpp
+63 -11
View File
@@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
}
}
// return a list of splits for a given path
// for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
static std::vector<std::string> llama_get_list_splits(const std::string & path, const int idx, const int n_split) {
std::vector<std::string> paths;
std::string split_prefix;
std::vector<char> buf(llama_path_max(), 0);
{
int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
if (!ret) {
throw std::runtime_error(format("invalid split file name: %s", path.c_str()));
}
split_prefix = std::string(buf.data(), ret);
}
if (split_prefix.empty()) {
throw std::runtime_error(format("invalid split file: %s", path.c_str()));
}
for (int idx = 0; idx < n_split; ++idx) {
int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
paths.push_back(std::string(buf.data(), ret));
}
return paths;
}
namespace GGUFMeta {
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)>
struct GKV_Base_Type {
@@ -413,7 +440,12 @@ namespace GGUFMeta {
template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
llama_model_loader::llama_model_loader(
const std::string & fname,
std::vector<std::string> & splits,
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p) {
int trace = 0;
if (getenv("LLAMA_TRACE")) {
trace = atoi(getenv("LLAMA_TRACE"));
@@ -425,6 +457,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
}
}
// Load the main GGUF
struct ggml_context * ctx = NULL;
struct gguf_init_params params = {
/*.no_alloc = */ true,
@@ -460,35 +493,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
// Load additional GGML contexts
if (n_split > 1) {
// make sure the main file is loaded first
uint16_t idx = 0;
get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
get_key(kv_split_no, idx);
if (idx != 0) {
throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
}
std::vector<char> split_prefix(llama_path_max(), 0);
if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
// generate list of splits if needed
if (splits.empty()) {
splits = llama_get_list_splits(fname, idx, n_split);
}
// in case user give a custom list of splits, check if it matches the expected number
if (n_split != (uint16_t)splits.size()) {
throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
}
if (trace > 0) {
LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
}
std::vector<char> split_path(llama_path_max(), 0);
// load other splits
for (idx = 1; idx < n_split; idx++) {
llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
const char * fname_split = splits[idx].c_str();
struct gguf_init_params split_params = {
/*.no_alloc = */ true,
/*.ctx = */ &ctx,
};
gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
if (!ctx_gguf) {
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
}
files.emplace_back(new llama_file(split_path.data(), "rb"));
// check idx
{
const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
if (kid < 0) {
throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
}
int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
if (idx_gguf != idx) {
throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
}
}
files.emplace_back(new llama_file(fname_split, "rb"));
contexts.emplace_back(ctx);
// Save tensors data offset info of the shard.
+6 -1
View File
@@ -90,7 +90,12 @@ struct llama_model_loader {
size_t size_data = 0;
std::vector<std::pair<size_t, size_t>> mmaps_used;
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p);
llama_model_loader(
const std::string & fname,
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
bool use_mmap,
bool check_tensors,
const struct llama_model_kv_override * param_overrides_p);
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
-1
View File
@@ -3717,7 +3717,6 @@ struct llama_model_params llama_model_default_params() {
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
/*.rpc_servers =*/ nullptr,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr,
-2
View File
@@ -323,8 +323,6 @@ struct llama_model {
// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
std::vector<std::string> rpc_servers;
// list of devices used in this model
std::vector<ggml_backend_dev_t> devices;
+2 -1
View File
@@ -526,7 +526,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
kv_overrides = v->data();
}
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
std::vector<std::string> splits = {};
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
ml.init_mappings(false); // no prefetching
llama_model model(llama_model_default_params());
+5 -4
View File
@@ -439,7 +439,7 @@ struct llm_tokenizer_bpe_session {
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (vocab.get_add_bos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) {
if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) {
LLAMA_LOG_WARN(
"%s: Added a EOS token to the prompt as specified by the model but the prompt "
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
@@ -1356,8 +1356,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// read vocab size from metadata
uint32_t n_tokens = 0;
if (!ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) {
LLAMA_LOG_WARN("%s: there is no vocab_size in metadata\n", __func__);
if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) {
LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens);
id_to_token.resize(n_tokens);
}
return;
@@ -1729,7 +1730,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
continue;
}
if (new_id >= id_to_token.size()) {
LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n",
__func__, key.c_str(), new_id, id);
} else {
id = new_id;
+39 -56
View File
@@ -31,7 +31,7 @@
#endif
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
// loading time will be recalculated after the first eval, so
// we take page faults deferred by mmap() into consideration
model.t_load_us = 0;
@@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
model.t_start_us = tm.t_start_us;
try {
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
ml.print_info();
@@ -4642,7 +4642,7 @@ struct llm_build_context {
0);
cb(v_states, "v_states", il);
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -4651,7 +4651,7 @@ struct llm_build_context {
cb(q_pe, "q_pe", il);
// shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -6496,7 +6496,7 @@ struct llm_build_context {
0);
cb(v_states, "v_states", il);
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -6505,7 +6505,7 @@ struct llm_build_context {
cb(q_pe, "q_pe", il);
// shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9374,14 +9374,9 @@ int64_t llama_time_us(void) {
return ggml_time_us();
}
struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params) {
return llama_model_load_from_file(path_model, params);
}
struct llama_model * llama_model_load_from_file(
const char * path_model,
static struct llama_model * llama_model_load_from_file_impl(
const std::string & path_model,
std::vector<std::string> & splits,
struct llama_model_params params) {
ggml_time_init();
@@ -9404,47 +9399,6 @@ struct llama_model * llama_model_load_from_file(
};
}
if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
// split the servers set them into model->rpc_servers
std::string servers(params.rpc_servers);
size_t pos = 0;
while ((pos = servers.find(',')) != std::string::npos) {
std::string server = servers.substr(0, pos);
model->rpc_servers.push_back(server);
servers.erase(0, pos + 1);
}
model->rpc_servers.push_back(servers);
}
// add RPC devices
if (!model->rpc_servers.empty()) {
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
if (!rpc_reg) {
LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
llama_model_free(model);
return nullptr;
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
llama_model_free(model);
return nullptr;
}
for (const std::string & server : model->rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
model->devices.push_back(dev);
} else {
LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
llama_model_free(model);
return nullptr;
}
}
}
// create list of devices to use with this model
if (params.devices) {
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
@@ -9485,7 +9439,7 @@ struct llama_model * llama_model_load_from_file(
LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
}
const int status = llama_model_load(path_model, *model, params);
const int status = llama_model_load(path_model, splits, *model, params);
GGML_ASSERT(status <= 0);
if (status < 0) {
if (status == -1) {
@@ -9501,6 +9455,35 @@ struct llama_model * llama_model_load_from_file(
return model;
}
// deprecated
struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params) {
return llama_model_load_from_file(path_model, params);
}
struct llama_model * llama_model_load_from_file(
const char * path_model,
struct llama_model_params params) {
std::vector<std::string> splits = {};
return llama_model_load_from_file_impl(path_model, splits, params);
}
struct llama_model * llama_model_load_from_splits(
const char ** paths,
size_t n_paths,
struct llama_model_params params) {
std::vector<std::string> splits;
if (n_paths == 0) {
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
return nullptr;
}
for (size_t i = 0; i < n_paths; ++i) {
splits.push_back(paths[i]);
}
return llama_model_load_from_file_impl(splits.front(), splits, params);
}
struct llama_context * llama_init_from_model(
struct llama_model * model,
struct llama_context_params params) {
+319 -54
View File
@@ -780,7 +780,7 @@ struct test_case {
}
}
if (!any_params) {
printf("not supported [%s] \n", op_name);
printf("not supported [%s] \n", op_desc(out).c_str());
supported = false;
}
if (!supported) {
@@ -1130,6 +1130,59 @@ struct test_get_rows : public test_case {
}
};
// GGML_OP_GET_ROWS_BACK
struct test_get_rows_back : public test_case {
const ggml_type type;
const int n; // cols
const int m; // rows
const int r; // rows to get
const int b; // batch size
const bool v; // view (non-contiguous src1)
std::string vars() override {
return VARS_TO_STR6(type, n, m, r, b, v);
}
test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
: type(type), n(n), m(m), r(r), b(b), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);
ggml_set_name(in_forward, "in_forward");
ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
ggml_set_name(rows, "rows");
if (v) {
rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
ggml_set_name(rows, "view_of_rows");
}
ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);
ggml_set_name(grad, "grad");
ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// rows
std::vector<int> data(r*b);
for (int i = 0; i < r*b; i++) {
data[i] = rand() % m;
}
ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
@@ -1531,6 +1584,39 @@ struct test_scale : public test_case {
}
};
// GGML_OP_SILU_BACK
struct test_silu_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}
test_silu_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(grad, "grad");
ggml_tensor * out = ggml_silu_back(ctx, a, grad);
ggml_set_name(out, "out");
return out;
}
bool grad_precise() override {
return true;
}
};
// GGML_OP_NORM
struct test_norm : public test_case {
const ggml_type type;
@@ -1583,11 +1669,56 @@ struct test_rms_norm : public test_case {
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.f, 10.f);
}
}
float grad_eps() override {
return 1.0f;
}
bool grad_precise() override {
return true;
}
};
// GGML_OP_RMS_NORM_BACK
struct test_rms_norm_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}
test_rms_norm_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.f, 10.f);
}
}
};
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
@@ -1855,10 +1986,11 @@ struct test_out_prod : public test_case {
const int64_t n;
const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const bool trans_b;
std::string vars() override {
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b);
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);
}
double max_nmse_err() override {
@@ -1868,8 +2000,9 @@ struct test_out_prod : public test_case {
test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
bool trans_b = false)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {}
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
@@ -1877,10 +2010,10 @@ struct test_out_prod : public test_case {
ggml_tensor * b;
if (trans_b) {
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
b = ggml_transpose(ctx, b);
} else {
b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);
}
ggml_set_name(b, "b");
@@ -2191,8 +2324,38 @@ struct test_soft_max : public test_case {
}
};
// GGML_OP_SOFT_MAX_BACK
struct test_soft_max_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float scale;
const float max_bias;
// GGML_OP_ROPE
std::string vars() override {
return VARS_TO_STR4(type, ne, scale, max_bias);
}
test_soft_max_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3},
float scale = 1.0f,
float max_bias = 0.0f)
: type(type), ne(ne), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_ROPE + GGML_OP_ROPE_BACK
struct test_rope : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
@@ -2204,29 +2367,36 @@ struct test_rope : public test_case {
float af; // attn_factor
bool ff;
int v; // view (1 : non-contiguous a)
bool forward;
std::string vars() override {
// forward can be inferred from the op, does not need to be printed
return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
}
test_rope(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
: type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a;
if (v & 1) {
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
if (forward) {
ggml_set_param(ctx, a);
}
ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a");
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(ctx, a);
if (forward) {
ggml_set_param(ctx, a);
}
ggml_set_name(a, "a");
}
@@ -2252,14 +2422,26 @@ struct test_rope : public test_case {
if (is_vision) {
GGML_ASSERT(n_dims/4 > 0);
int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
if (forward) {
out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
} else {
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
}
} else {
GGML_ASSERT(n_dims/3 > 0);
int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
if (forward) {
out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
} else {
out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
}
}
} else {
out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
if (forward) {
out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
} else {
out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
}
}
ggml_set_name(out, "out");
@@ -2864,9 +3046,10 @@ struct test_flash_attn_ext : public test_case {
const float logit_softcap; // Gemma 2
const ggml_type type_KV;
std::array<int32_t, 4> permute;
std::string vars() override {
return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
}
double max_nmse_err() override {
@@ -2881,19 +3064,33 @@ struct test_flash_attn_ext : public test_case {
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
int64_t ne[4] = {ne0, ne1, ne2, ne3};
int64_t ne_perm[4];
for (int i = 0; i < 4; ++i) {
ne_perm[permute[i]] = ne[i];
}
ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
}
return t;
};
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
ggml_set_name(q, "q");
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(k, "k");
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
@@ -2961,6 +3158,40 @@ struct test_cross_entropy_loss : public test_case {
}
};
// GGML_OP_CROSS_ENTROPY_LOSS_BACK
struct test_cross_entropy_loss_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
ggml_set_name(grad, "grad");
ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(logits, "logits");
ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(labels, "labels");
// Ensure labels add up to 1:
labels = ggml_soft_max(ctx, labels);
ggml_set_name(labels, "labels_normalized");
ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_OPT_STEP_ADAMW
struct test_opt_step_adamw : public test_case {
const ggml_type type;
@@ -3460,6 +3691,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
}
}
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
}
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {
@@ -3582,6 +3823,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
}
}
for (ggml_type type_dst : {GGML_TYPE_F32}) {
for (ggml_type type_src : all_types) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
}
}
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
@@ -3638,10 +3885,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_silu_back());
for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
@@ -3781,22 +4030,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}, true));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
for (int n : {1, 16}) {
for (int k : {1, 16}) {
for (int bs2 : {1, 3}) {
for (int bs3 : {1, 3}) {
for (int nr2 : {1, 2}) {
for (int nr3 : {1, 2}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));
}
}
}
}
}
}
}
}
@@ -3839,12 +4085,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
{
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
for (int64_t ne1 : {16, 1024}) {
test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 1, 1}, scale, max_bias));
test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
}
}
}
}
for (bool fw : {true, false}) { // fw == forward
bool all = true;
for (float v : { 0, 1 }) {
@@ -3853,29 +4110,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (float af : { 1.0f, 1.4245f }) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (bool ff : {false, true}) { // freq_factors
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
if (all) {
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
}
if (all) {
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
}
if (all) {
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
}
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
}
}
@@ -3925,6 +4182,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
}
}
}
}
@@ -3934,7 +4195,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_cross_entropy_loss());
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
// these tests are disabled to save execution time, but they can be handy for debugging
+162 -132
View File
@@ -9,7 +9,7 @@
#include "common.h"
int main(void) {
llama_chat_message conversation[] = {
std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
@@ -17,130 +17,161 @@ int main(void) {
{"assistant", " I am an assistant "},
{"user", "Another question"},
};
size_t message_count = 6;
std::vector<std::string> templates = {
// teknium/OpenHermes-2.5-Mistral-7B
"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
// mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
// TheBloke/FusionNet_34Bx2_MoE-AWQ
"{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
// bofenghuang/vigogne-2-70b-chat
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
// mlabonne/AlphaMonarch-7B
"{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
// google/gemma-7b-it
"{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
// OrionStarAI/Orion-14B-Chat
"{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
// openchat/openchat-3.5-0106
// The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
// So we match against the included template but implement the suggested version.
"{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
// deepseek-ai/deepseek-coder-33b-instruct
"{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
// eachadea/vicuna-13b-1.1
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
"{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
// Orca-Vicuna
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
"{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
// CohereForAI/c4ai-command-r-plus
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
// Llama-3
"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
//Phi-3-mini
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
//Phi-3-small
"{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
//Phi-3-medium
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
//Phi-3-vision
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
// ChatGLM3
"{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
// ChatGLM4
u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
// DeepSeek-V2
"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
// ibm-granite/granite-3.0-8b-instruct
"{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
// mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)
"{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
// Mistral-Large-Instruct-2407 (mistralai 'v3' template)
"{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
// Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template)
"{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
// mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
"{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
// ai-sage/GigaChat-20B-A3B-instruct
"{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
// Infinigence/Megrez-3B-Instruct
u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
// phi-4
"{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
struct TestCase {
std::string name;
std::string template_str;
std::string expected_output;
};
std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
// mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
"[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
// TheBloke/FusionNet_34Bx2_MoE-AWQ
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
// bofenghuang/vigogne-2-70b-chat
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
// mlabonne/AlphaMonarch-7B
"system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
// google/gemma-7b-it
"<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
// OrionStarAI/Orion-14B-Chat
"Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
// openchat/openchat-3.5-0106
"You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
// deepseek-ai/deepseek-coder-33b-instruct
"You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
// eachadea/vicuna-13b-1.1
"You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
// Orca-Vicuna
"SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
// CohereForAI/c4ai-command-r-plus
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
// Llama 3
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
//Phi-3-mini
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
//Phi-3-small
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
//Phi-3-medium
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
//Phi-3-vision
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
// ChatGLM3
"[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
// ChatGLM4
"[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
// DeepSeek-V2
u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<end▁of▁sentence>User: Who are you\n\nAssistant: I am an assistant <end▁of▁sentence>User: Another question\n\nAssistant:",
// ibm-granite/granite-3.0-8b-instruct
"<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
// mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)
" [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
// Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)
"[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
// Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)
"[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
// mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
"[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant </s>[INST] Another question[/INST]",
// ai-sage/GigaChat-20B-A3B-instruct
"<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
// Infinigence/Megrez-3B-Instruct
"<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
// phi-4
"<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
std::vector<TestCase> test_cases {
{
/* .name= */ "teknium/OpenHermes-2.5-Mistral-7B",
/* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
/* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
},
{
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
},
{
/* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ",
/* .template_str= */ "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
/* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
},
{
/* .name= */ "bofenghuang/vigogne-2-70b-chat",
/* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
},
{
/* .name= */ "mlabonne/AlphaMonarch-7B",
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
},
{
/* .name= */ "google/gemma-7b-it",
/* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
/* .expected_output= */ "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
},
{
/* .name= */ "OrionStarAI/Orion-14B-Chat",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
},
{
/* .name= */ "openchat/openchat-3.5-0106",
// The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
// So we match against the included template but implement the suggested version.
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
/* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
},
{
/* .name= */ "deepseek-ai/deepseek-coder-33b-instruct",
/* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
/* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
},
{
/* .name= */ "eachadea/vicuna-13b-1.1",
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
/* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
/* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
},
{
/* .name= */ "Orca-Vicuna",
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
/* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
/* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
},
{
/* .name= */ "CohereForAI/c4ai-command-r-plus",
/* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
/* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
},
{
/* .name= */ "Llama-3",
/* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
/* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
},
{
/* .name= */ "Phi-3-mini",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
},
{
/* .name= */ "Phi-3-small",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
},
{
/* .name= */ "Phi-3-medium",
/* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
},
{
/* .name= */ "Phi-3-vision",
/* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
},
{
/* .name= */ "ChatGLM3",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
/* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
},
{
/* .name= */ "ChatGLM4",
/* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
/* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
},
{
/* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
/* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
/* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
},
{
/* .name= */ "DeepSeek-V2",
/* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
/* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<end▁of▁sentence>User: Who are you\n\nAssistant: I am an assistant <end▁of▁sentence>User: Another question\n\nAssistant:",
},
{
/* .name= */ "ibm-granite/granite-3.0-8b-instruct",
/* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
},
{
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
/* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
},
{
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
},
{
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
},
{
/* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant </s>[INST] Another question[/INST]",
},
{
/* .name= */ "ai-sage/GigaChat-20B-A3B-instruct",
/* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
/* .expected_output= */ "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
},
{
/* .name= */ "Infinigence/Megrez-3B-Instruct",
/* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
/* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
},
{
/* .name= */ "phi-4",
/* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
/* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
},
};
std::vector<char> formatted_chat(1024);
int32_t res;
@@ -157,17 +188,16 @@ int main(void) {
}
// test invalid chat template
res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size());
assert(res < 0);
for (size_t i = 0; i < templates.size(); i++) {
std::string custom_template = templates[i];
std::string expected = expected_output[i];
for (const auto & test_case : test_cases) {
printf("\n\n=== %s ===\n\n", test_case.name.c_str());
formatted_chat.resize(1024);
res = llama_chat_apply_template(
custom_template.c_str(),
conversation,
message_count,
test_case.template_str.c_str(),
conversation.data(),
conversation.size(),
true,
formatted_chat.data(),
formatted_chat.size()
@@ -176,7 +206,7 @@ int main(void) {
std::string output(formatted_chat.data(), formatted_chat.size());
printf("%s\n", output.c_str());
printf("-------------------------\n");
assert(output == expected);
assert(output == test_case.expected_output);
}
+3 -3
View File
@@ -80,18 +80,18 @@ run_conversion_and_inference_lora() {
# Run inference
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..."
OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
OUTPUT_BASE=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
-p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0)
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
OUTPUT_LORA_HOT=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
--lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
OUTPUT_LORA_MERGED=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
# Remove any initial white space