Compare commits

..

60 Commits

Author SHA1 Message Date
Diego Devesa 5364ae4ba5 llama : print hint when loading a model when no backends are loaded (#13589) 2025-05-16 16:38:07 +02:00
Sigbjørn Skjæret 7c07ac244d ci : add ppc64el to build-linux-cross (#13575) 2025-05-16 14:54:23 +02:00
Łukasz Ślusarczyk 0a338ed013 sycl : fixed compilation warnings (#13582) 2025-05-16 18:15:29 +08:00
Olivier Chafik bc098c3cf0 minja: sync (qwen3) (#13573)
* minja: sync https://github.com/google/minja/commit/f06140fa52fd140fe38e531ec373d8dc9c86aa06

- https://github.com/google/minja/pull/67 (@grf53)
- https://github.com/google/minja/pull/66 (@taha-yassine)
- https://github.com/google/minja/pull/63 (@grf53)
- https://github.com/google/minja/pull/58

---------

Co-authored-by: ochafik <ochafik@google.com>
2025-05-15 23:29:10 +01:00
Diego Devesa c6a2c9e741 gguf : use ggml log system (#13571)
* gguf : use ggml log system

* llama : remove unnecessary new lines in exception messages
2025-05-15 19:13:11 +02:00
Daniel Tang 07ad2b6db3 gguf-py : fix disconnect-before-connect in editor-gui (#13569)
The bug caused a crash upon load with venvs created with
--system-site-packages to use
python3-pyside6.qtwidgets=python3-pyside6.qtwidgets=6.6.2-4
from Kubuntu 24.10.
2025-05-15 18:47:10 +02:00
Xuan-Son Nguyen c531edfa34 convert : fix conversion for llama 4 (#13567) 2025-05-15 17:40:07 +02:00
Atharva Dubey 02cdd2d8b0 sycl: simplify bin_bcast_kernel (#13383) 2025-05-15 17:39:52 +02:00
Svetlozar Georgiev 64bb51cf90 sycl: reordered Q4_K MMVQ (#13109) 2025-05-15 17:35:44 +02:00
Łukasz Ślusarczyk 9c404ed54c sycl: use oneDNN for matrices multiplication (#12972) 2025-05-15 16:53:41 +02:00
Diego Devesa 6c8b91500e llama-bench : fix -ot with dl backends (#13563) 2025-05-15 15:46:55 +02:00
Xuan-Son Nguyen 3cc1f1f1d2 webui : handle PDF input (as text or image) + convert pasted long content to file (#13562)
* webui : handle PDF input (as text or image)

* handle the case where pdf image + server without mtmd

* fix bug missing pages
2025-05-15 14:24:50 +02:00
Piotr Wilkin (ilintar) c753d7bed0 server : proper error handling for missing elements in messages array (OpenAI compatible backend) (#13540) 2025-05-15 08:40:58 +02:00
Georgi Gerganov b2838049cc bench : handle decode errors (#13548)
ggml-ci
2025-05-15 05:57:02 +03:00
Olivier Chafik aa48e373f2 server: inject date_string in llama 3.x template + fix date for firefunction v2 (#12802)
* Inject date_string in llama 3.x + fix for functionary v2

https://github.com/ggml-org/llama.cpp/issues/12729

* move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* generate more tokens in test_completion_with_required_tool_tiny_fast to avoid truncation

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-05-15 02:39:51 +01:00
Georgi Gerganov e3a9421b78 kv-cache : fix out-of-bounds view during reserve graph (#13547)
* kv-cache : fix reserve graph out-of-bounds access

ggml-ci

* cont : add comment

* cont : fix comments [no ci]

* cont : more correct comment [no ci]
2025-05-14 23:15:15 +03:00
Yibo Cai 5ab5d5fb25 arm64: optimize q6_k_q8_k kernel with i8mm (#13519)
This PR improves q6_k_q8_k gemm kernel with arm64 i8mm instruction.

Tested on neoverse-n2 with llama3 8b q6_k quantization model.
- 40% ~ 54% S_PP uplift for all batch sizes
- 16% ~ 47% S_TG uplift for batch size 4 and above

Perplexity doesn't change with this PR.

```
// tested on neoverse-n2
$ llama-batched-bench \
      -m Meta-Llama-3-8B-Instruct-Q6_K.gguf \
      --no-mmap -fa \
      -c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
      -npl 1,2,4,8,16,32 \
      -t 64

---------------------------------------------------------------------
|    PP |     TG |    B |       S_PP t/s      |       S_TG t/s      |
|       |        |      | original |  this pr | original |  this pr |
|-------|--------|------|----------|----------|----------|----------|
|   128 |    128 |    1 |    78.52 |   109.18 |    18.63 |    18.88 |
|   128 |    128 |    2 |    84.62 |   123.94 |    34.54 |    36.92 |
|   128 |    128 |    4 |    84.36 |   122.49 |    52.65 |    61.32 |
|   128 |    128 |    8 |    90.52 |   138.87 |    63.46 |    84.41 |
|   128 |    128 |   16 |    90.11 |   138.56 |    71.04 |   101.33 |
|   128 |    128 |   32 |    89.81 |   137.79 |    75.14 |   110.47 |
---------------------------------------------------------------------
```
2025-05-14 21:53:52 +02:00
Olivier Chafik 3198405e98 common: add partial regex support (#12808)
* move string_find_partial_stop & string_ends_with to common

* add common_regex (supports partial matches)

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/regex-partial.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/regex-partial.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/regex-partial.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* partial regex: add missing iterator end checks

* string utils: use string_views

* direct throw to avoid ggml.h include

* regex-partial: replace missed ggml_asserts

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-05-14 19:50:57 +01:00
Sigbjørn Skjæret f5170c1d7a editorconfig : fix trailing whitespace from #13542 (#13546) 2025-05-14 21:22:49 +03:00
Gilad S. 017f10b5fa fix: crash when calling llama_state_get_size on a context without a KV cache (#13542) 2025-05-14 19:18:18 +03:00
Johannes Gäßler 4696d56749 CUDA: fix crash on large batch size for quant. MoE (#13537) 2025-05-14 16:41:02 +02:00
Diego Devesa b7d2672082 llama : fix quantize with dl backends (#13539) 2025-05-14 16:12:36 +02:00
Johannes Gäßler 6da34fa276 CUDA: faster Deepseek FA, add Turing support (#13435) 2025-05-14 16:08:20 +02:00
Gabe Goodhart 5e7d95e22e fix: Move build_inp_pos to the top of the graph section for build_granite (#13538)
This matches how others do it, but will still avoid the extra
initialization when rope is disabled.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-05-14 15:53:59 +03:00
Georgi Gerganov 053174436f server : passthrough the /models endpoint during loading (#13535)
* server : passthrough the /models endpoint during loading

* server : update readme + return json for "meta" field
2025-05-14 15:42:10 +03:00
Xuan-Son Nguyen 360a9c98e1 server : fix cache_tokens bug with no cache_prompt (#13533) 2025-05-14 13:35:07 +02:00
bandoti 09d13d94fb cmake: simplify vulkan shader test logic (#13263) 2025-05-14 07:53:57 -03:00
Jeff Bolz 24e86cae72 vulkan: KHR_coopmat flash attention (#13506)
This shader uses coopmat1 to do the Q*K^T multiply. The P*V multiply is more
difficult for various reasons so I haven't done it. Performance for this
shader is around 2.5x better than for the scalar shader when doing prompt
processing. Some of the benefit may be from other optimizations like staging
through shared memory, or splitting by rows.
2025-05-14 11:55:26 +02:00
Xuan-Son Nguyen bb1681fbd5 webui : use fflate for more deterministic gzip compress (#13525)
* webui : use pako for more deterministic gzip compress

* simpler code

* use fflate instead of pako
2025-05-14 10:26:12 +02:00
Luca Stefani d486dd3e8e webui: Allow pasting file from clipboard (#13526)
* server: Allow pasting file from clipboard

* server: Prevent default action on file paste

* update build

* format then build combined

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-05-14 10:07:31 +02:00
ddpasa 21ca987fba docs: Update link to ggml-org in multimodal.md (#13513)
* Update multimodal.md

Minor change to include the huggingface link

* Update docs/multimodal.md

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-05-14 09:59:12 +02:00
Sigbjørn Skjæret be1d4a13db scripts : fix compare-llama-bench.py show parameter (#13514) 2025-05-14 08:41:01 +02:00
Jeff Bolz ab3971f2a0 vulkan: workaround FA compile failures on macos (#13517) 2025-05-14 06:15:50 +02:00
Ed Addario e5c834f718 quantize : improve tensor-type pattern matching (#13033) 2025-05-13 19:12:31 +02:00
Xuan-Son Nguyen 71bdbdb587 clip : clip.h become private API (⚠️ breaking change) (#13510) 2025-05-13 17:07:21 +02:00
Georgi Gerganov f0995d28ce metal : use FA-vec kernel up to batch size 20 (#13496)
* batched-bench : fix pp batch contents

* metal : optimize multi-sequence FA vec kernel

ggml-ci

* metal : use FA-vec kernel up to batch size 20

ggml-ci
2025-05-13 18:04:39 +03:00
Georgi Gerganov c252e0c409 metal : optimize multi-sequence FA vec kernel (#13493)
* batched-bench : fix pp batch contents

* metal : optimize multi-sequence FA vec kernel

ggml-ci
2025-05-13 18:04:00 +03:00
Dan Johansson 4f711afed5 ggml-cpu: Update KleidiAI to v1.6 and fix include directives (#13509)
Signed-off-by: Dan Johansson <dan.johansson@arm.com>
2025-05-13 18:02:28 +03:00
Georgi Gerganov b89d605a91 batched-bench : fix pp batch contents (#13492) 2025-05-13 18:01:53 +03:00
Xuan-Son Nguyen b4726345ac mtmd : remove libllava, remove clip-quantize-cli (⚠️ breaking change) (#13460)
* mtmd : remove libllava, remove clip-quantize-cli

* rm clip_model_quantize
2025-05-13 15:33:58 +02:00
Sigbjørn Skjæret bf79371120 scripts : support arbitrary input file formats in compare-llama-bench.py (#13455) 2025-05-13 15:31:12 +02:00
Gabe Goodhart d590cd4c24 model : Granite MoE shared (#13269)
* feat: Add GGUF conversion for granitemoeshared

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: hparam and arch plumbing for granitemoeshared

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Split MoE fused tensors for shared experts in conversion

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: First WIP cut at model arch in cpp

The hparam and architecture plumbing should be correct, but the
implementation of the shared experts seems to still be broken.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Cleaner (maybe more correct?) splitting for gate/up

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix the input to the shared experts

I had misread that the shared experts take the inputs _before_ the standard
MoE layer and was feeding the output of the MoE to the shared experts.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Avoid architecture-specific checks for Granite MoE Shared

This is a cleaner way that will allow more flexibility in architecture
strings going forward.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Split granite architectures out of llm_build_llama

This helps de-clutter the llama-family graph construction and allows
granite to diverge further (in preparation for Granite 4).

NOTE: I removed the granite scale factors from llm_build_deci because they
appear to only be there as copy-paste from llm_build_llama. The HF config
does not seem to set those values:
https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix compiler warning about uninitialized inp_pos

This should not have been reachable, but it warns on some compliers

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Consoladate GraniteMoEShared into GraniteMoE for conversion

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Consolidate GraniteMoEShared into GraniteMoE on the c++ side

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-05-13 15:12:01 +02:00
Georgi Gerganov 1e2809bc4b sync : ggml 2025-05-13 14:02:28 +03:00
Diego Devesa cf0a43bb64 llama-bench : add defrag-thold, check for invalid ranges (#13487) 2025-05-13 00:31:37 +02:00
lhez f0d46ef157 opencl: remove unnecessary assert for add (#13257) 2025-05-12 13:13:49 -07:00
Xuan-Son Nguyen de4c07f937 clip : cap max image size 1024 for qwen vl model (#13478) 2025-05-12 15:06:51 +02:00
Johannes Gäßler 10d2af0eaa llama/ggml: add LLM training support (#10544)
* llama/ggml: add LLM training support

more compact progress bar

llama_save_model_to_file

llama_opt_param_filter

ggml_graph_dup force_grads

refactor ggml_opt, fix test-opt

* remove logits_all

* refactor CUDA implementation for ACC

* reset graph at beginning of opt period
2025-05-12 14:44:49 +02:00
Georgi Gerganov 064cc596ac context : fix state io for memory-less contexts (#13470)
ggml-ci
2025-05-12 15:12:27 +03:00
Anudit Nagar 91159ee9df server : allow content to be null in oaicompat_completion_params_parse (#13477) 2025-05-12 13:56:42 +02:00
Diego Devesa 22cdab343b llama-bench : accept ranges for integer parameters (#13410) 2025-05-12 13:08:22 +02:00
Dan Johansson a71a4075cd ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (#13053)
* ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel

Signed-off-by: Dan Johansson <dan.johansson@arm.com>

* * code review fixes

Signed-off-by: Dan Johansson <dan.johansson@arm.com>

* * adds a comment that clarifies barrier usage

Signed-off-by: Dan Johansson <dan.johansson@arm.com>

---------

Signed-off-by: Dan Johansson <dan.johansson@arm.com>
Co-authored-by: Charles Xu <charles.xu@arm.com>
2025-05-12 13:06:19 +02:00
Johannes Gäßler 95e18884fc CUDA: fix misaligned synchronization in FA (#13469) 2025-05-12 10:51:21 +02:00
Xuan-Son Nguyen df8491922f ggml : add mrope kernel for metal (#13457) 2025-05-12 10:29:13 +02:00
Atharva Dubey 14492144c2 enable dpcpp nightly builds with libraries (#13406) 2025-05-12 13:15:32 +08:00
City c104023994 mtmd : Use RMS norm for InternVL 3 38B and 78B mmproj (#13459) 2025-05-12 00:39:06 +02:00
Anthony Umfer 9a390c4829 tools : fix uninitialized llama_batch in server (#13436)
* add constructor to initialize server_context::batch, preventing destructor's call to llama_batch_free from causing an invalid free()

* Update tools/server/server.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* use C++11 initializer syntax

* switch from Copy-list-initialization to Direct-list-initialization

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2025-05-11 17:08:26 +02:00
Sigbjørn Skjæret 09232370fc scripts : exit compare-llama-bench.py gracefully when there's nothing to compare (#13451) 2025-05-11 16:20:39 +02:00
Johannes Gäßler 7474e00b34 CUDA: fix crash with partial offloading of MoE (#13439) 2025-05-11 16:09:33 +02:00
David Huang 7f323a589f Add --no-op-offload to improve -ot pp perf in MoE models like llama4 400B (#13386) 2025-05-11 14:18:39 +02:00
City 3eac209319 mtmd : support InternVL 3 38B and 78B mmproj (#13443)
* Support InternVL 3 38B and 78B mmproj

* Swap norms in clip.cpp

* Group variables together
2025-05-11 11:35:52 +02:00
129 changed files with 6706 additions and 3761 deletions
+91
View File
@@ -140,3 +140,94 @@ jobs:
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
ubuntu-24-ppc64el-cpu-cross:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- name: Setup PowerPC64le
run: |
sudo dpkg --add-architecture ppc64el
# Add arch-specific repositories for non-amd64 architectures
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
EOF
sudo apt-get update || true ;# Prevent failure due to missing URLs.
sudo apt-get install -y --no-install-recommends \
build-essential \
gcc-14-powerpc64le-linux-gnu \
g++-14-powerpc64le-linux-gnu \
libcurl4-openssl-dev:ppc64el
- name: Build
run: |
cmake -B build -DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
ubuntu-24-ppc64el-vulkan-cross:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- name: Setup PowerPC64le
run: |
sudo dpkg --add-architecture ppc64el
# Add arch-specific repositories for non-amd64 architectures
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
EOF
sudo apt-get update || true ;# Prevent failure due to missing URLs.
sudo apt-get install -y --no-install-recommends \
build-essential \
glslc \
gcc-14-powerpc64le-linux-gnu \
g++-14-powerpc64le-linux-gnu \
libvulkan-dev:ppc64el \
libcurl4-openssl-dev:ppc64el
- name: Build
run: |
cmake -B build -DCMAKE_BUILD_TYPE=Release \
-DGGML_VULKAN=ON \
-DGGML_OPENMP=OFF \
-DLLAMA_BUILD_EXAMPLES=ON \
-DLLAMA_BUILD_TOOLS=ON \
-DLLAMA_BUILD_TESTS=OFF \
-DCMAKE_SYSTEM_NAME=Linux \
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
cmake --build build --config Release -j $(nproc)
+1
View File
@@ -117,6 +117,7 @@ setup_framework_structure() {
# Copy all required headers (common for all platforms)
cp include/llama.h ${header_path}
cp ggml/include/ggml.h ${header_path}
cp ggml/include/ggml-opt.h ${header_path}
cp ggml/include/ggml-alloc.h ${header_path}
cp ggml/include/ggml-backend.h ${header_path}
cp ggml/include/ggml-metal.h ${header_path}
+2
View File
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
minja/minja.hpp
ngram-cache.cpp
ngram-cache.h
regex-partial.cpp
regex-partial.h
sampling.cpp
sampling.h
speculative.cpp
+7
View File
@@ -2437,6 +2437,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"--no-op-offload"},
string_format("disable offloading host tensor operations to device (default: %s)", params.no_op_offload ? "true" : "false"),
[](common_params & params) {
params.no_op_offload = true;
}
));
add_opt(common_arg(
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
+130 -110
View File
@@ -6,6 +6,15 @@
#include <optional>
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
auto time = std::chrono::system_clock::to_time_t(now);
auto local_time = *std::localtime(&time);
std::ostringstream ss;
ss << std::put_time(&local_time, format.c_str());
auto res = ss.str();
return res;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
@@ -24,6 +33,7 @@ struct templates_params {
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
if (!inputs.tools.is_null()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
expect_tool_parameters(name, parameters, {"code"});
} else {
return false;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
expect_tool_parameters(name, parameters, {"code"});
} else {
return false;
}
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
builtin_tools.push_back(name);
return true;
};
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
if (allow_python_tag_builtin_tools) {
handle_builtin_tool(name, parameters);
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"{\" space "
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
"\"}\" space"));
});
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
});
if (!builtin_tools.empty()) {
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
builtin_tools.push_back(name);
return true;
};
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
if (allow_python_tag_builtin_tools) {
handle_builtin_tool(name, parameters);
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"{\" space "
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
"\"}\" space"));
// Allow a few empty lines on top of the usual constrained json schema space rule.
builder.add_rule("root", string_join(tool_rules, " | "));
data.additional_stops.push_back("<|eom_id|>");
});
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
});
if (!builtin_tools.empty()) {
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
// Allow a few empty lines on top of the usual constrained json schema space rule.
builder.add_rule("root", string_join(tool_rules, " | "));
});
data.additional_stops.push_back("<|eom_id|>");
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
: COMMON_CHAT_FORMAT_LLAMA_3_X;
} else {
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
}
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
{"date_string", format_time(inputs.now, "%d %b %Y")},
{"tools_in_user_message", false},
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
});
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
: COMMON_CHAT_FORMAT_LLAMA_3_X;
return data;
}
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const auto & parameters = function.at("parameters");
std::string name = function.at("name");
if (name == "python" || name == "ipython") {
if (!parameters.contains("type")) {
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
if (it.value().at("type") == "string") {
if (!python_code_argument_name.empty()) {
throw std::runtime_error("Multiple string arguments found in python tool");
if (!inputs.tools.is_null()) {
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const auto & parameters = function.at("parameters");
std::string name = function.at("name");
if (name == "python" || name == "ipython") {
if (!parameters.contains("type")) {
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
if (it.value().at("type") == "string") {
if (!python_code_argument_name.empty()) {
throw std::runtime_error("Multiple string arguments found in python tool");
}
python_code_argument_name = it.key();
}
python_code_argument_name = it.key();
}
if (python_code_argument_name.empty()) {
throw std::runtime_error("No string argument found in python tool");
}
} else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump());
}
if (python_code_argument_name.empty()) {
throw std::runtime_error("No string argument found in python tool");
}
} else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump());
}
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
});
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
});
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
});
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
} else {
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
}
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
// TODO: if (has_raw_python)
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
return data;
}
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
params.extract_reasoning = inputs.extract_reasoning;
params.tool_choice = inputs.tool_choice;
params.grammar = inputs.grammar;
params.now = inputs.now;
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_firefunction_v2(tmpl, params);
}
// Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
}
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}
// Llama 3.1, 3.2, 3.3 (w/ tools)
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
}
// Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
}
// Mistral Nemo (w/ tools)
+2
View File
@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include <chrono>
#include <string>
#include <vector>
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
struct common_chat_params {
+37
View File
@@ -443,6 +443,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
@@ -1113,6 +1132,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) {
cparams.embeddings = true;
@@ -1564,3 +1584,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result;
}
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}
+11 -4
View File
@@ -6,6 +6,7 @@
#include <set>
#include <string>
#include <string_view>
#include <vector>
#include <sstream>
@@ -332,6 +333,7 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation
@@ -502,10 +504,9 @@ static bool string_starts_with(const std::string & str,
return str.rfind(prefix, 0) == 0;
}
static bool string_ends_with(const std::string & str,
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -665,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
}
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+9 -5
View File
@@ -13,10 +13,12 @@
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <ctime>
#include <exception>
#include <iomanip>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
@@ -393,8 +395,8 @@ class chat_template {
for (const auto & message_ : adjusted_messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
}
std::string role = message.at("role");
@@ -415,7 +417,6 @@ class chat_template {
}
}
if (polyfill_tool_calls) {
auto content = message.at("content");
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
if (tool_call.at("type") != "function") {
@@ -434,8 +435,11 @@ class chat_template {
auto obj = json {
{"tool_calls", tool_calls},
};
if (!content.is_null() && !content.empty()) {
obj["content"] = content;
if (message.contains("content")) {
auto content = message.at("content");
if (!content.is_null() && !content.empty()) {
obj["content"] = content;
}
}
message["content"] = obj.dump(2);
message.erase("tool_calls");
+69 -36
View File
@@ -11,6 +11,7 @@
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cmath>
#include <exception>
#include <functional>
@@ -233,7 +234,7 @@ public:
}
} else if (is_object()) {
if (!index.is_hashable())
throw std::runtime_error("Unashable type: " + index.dump());
throw std::runtime_error("Unhashable type: " + index.dump());
auto it = object_->find(index.primitive_);
if (it == object_->end())
throw std::runtime_error("Key not found: " + index.dump());
@@ -252,7 +253,7 @@ public:
auto index = key.get<int>();
return array_->at(index < 0 ? array_->size() + index : index);
} else if (object_) {
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
auto it = object_->find(key.primitive_);
if (it == object_->end()) return Value();
return it->second;
@@ -261,7 +262,7 @@ public:
}
void set(const Value& key, const Value& value) {
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
(*object_)[key.primitive_] = value;
}
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@@ -398,7 +399,7 @@ public:
}
return false;
} else if (object_) {
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
return object_->find(value.primitive_) != object_->end();
} else {
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
@@ -416,7 +417,7 @@ public:
return const_cast<Value*>(this)->at(index);
}
Value& at(const Value & index) {
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
if (is_array()) return array_->at(index.get<int>());
if (is_object()) return object_->at(index.primitive_);
throw std::runtime_error("Value is not an array or object: " + dump());
@@ -676,8 +677,8 @@ public:
class VariableExpr : public Expression {
std::string name;
public:
VariableExpr(const Location & location, const std::string& n)
: Expression(location), name(n) {}
VariableExpr(const Location & loc, const std::string& n)
: Expression(loc), name(n) {}
std::string get_name() const { return name; }
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!context->contains(name)) {
@@ -1200,9 +1201,9 @@ public:
class SliceExpr : public Expression {
public:
std::shared_ptr<Expression> start, end;
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
std::shared_ptr<Expression> start, end, step;
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
Value do_evaluate(const std::shared_ptr<Context> &) const override {
throw std::runtime_error("SliceExpr not implemented");
}
@@ -1219,18 +1220,35 @@ public:
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
auto target_value = base->evaluate(context);
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
auto len = target_value.size();
auto wrap = [len](int64_t i) -> int64_t {
if (i < 0) {
return i + len;
}
return i;
};
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
if (!step) {
throw std::runtime_error("slice step cannot be zero");
}
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
if (target_value.is_string()) {
std::string s = target_value.get<std::string>();
if (start < 0) start = s.size() + start;
if (end < 0) end = s.size() + end;
return s.substr(start, end - start);
std::string result;
if (start < end && step == 1) {
result = s.substr(start, end - start);
} else {
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
result += s[i];
}
}
return result;
} else if (target_value.is_array()) {
if (start < 0) start = target_value.size() + start;
if (end < 0) end = target_value.size() + end;
auto result = Value::array();
for (auto i = start; i < end; ++i) {
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
result.push_back(target_value.at(i));
}
return result;
@@ -1305,6 +1323,8 @@ public:
if (name == "iterable") return l.is_iterable();
if (name == "sequence") return l.is_array();
if (name == "defined") return !l.is_null();
if (name == "true") return l.to_bool();
if (name == "false") return !l.to_bool();
throw std::runtime_error("Unknown type for 'is' operator: " + name);
};
auto value = eval();
@@ -1520,6 +1540,10 @@ public:
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
auto suffix = vargs.args[0].get<std::string>();
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
} else if (method->get_name() == "startswith") {
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
auto prefix = vargs.args[0].get<std::string>();
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
} else if (method->get_name() == "title") {
vargs.expectArgs("title method", {0, 0}, {0, 0});
auto res = str;
@@ -2082,28 +2106,37 @@ private:
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
if (!consumeToken("[").empty()) {
std::shared_ptr<Expression> index;
std::shared_ptr<Expression> index;
auto slice_loc = get_location();
std::shared_ptr<Expression> start, end, step;
bool has_first_colon = false, has_second_colon = false;
if (!peekSymbols({ ":" })) {
start = parseExpression();
}
if (!consumeToken(":").empty()) {
has_first_colon = true;
if (!peekSymbols({ ":", "]" })) {
end = parseExpression();
}
if (!consumeToken(":").empty()) {
auto slice_end = parseExpression();
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
} else {
auto slice_start = parseExpression();
if (!consumeToken(":").empty()) {
consumeSpaces();
if (peekSymbols({ "]" })) {
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
} else {
auto slice_end = parseExpression();
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
}
} else {
index = std::move(slice_start);
has_second_colon = true;
if (!peekSymbols({ "]" })) {
step = parseExpression();
}
}
if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
}
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
if ((has_first_colon || has_second_colon) && (start || end || step)) {
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
} else {
index = std::move(start);
}
if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
} else if (!consumeToken(".").empty()) {
auto identifier = parseIdentifier();
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
+204
View File
@@ -0,0 +1,204 @@
#include "regex-partial.h"
#include "common.h"
#include <functional>
#include <optional>
common_regex::common_regex(const std::string & pattern) :
pattern(pattern),
rx(pattern),
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
std::smatch match;
if (pos > input.size()) {
throw std::runtime_error("Position out of bounds");
}
auto start = input.begin() + pos;
auto found = as_match
? std::regex_match(start, input.end(), match, rx)
: std::regex_search(start, input.end(), match, rx);
if (found) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
for (size_t i = 0; i < match.size(); ++i) {
auto begin = pos + match.position(i);
res.groups.emplace_back(begin, begin + match.length(i));
}
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
if ((!as_match) || it == input.begin()) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
const size_t begin = std::distance(input.begin(), it);
const size_t end = input.size();
if (begin == std::string::npos || end == std::string::npos || begin > end) {
throw std::runtime_error("Invalid range");
}
res.groups.push_back({begin, end});
return res;
}
}
}
return {};
}
/*
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
- /a|b/ -> (a|b).*
- /a*?/ -> error, could match ""
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
- /.*?ab/ -> ((?:b)?a).* (merge .*)
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
const auto end = pattern.end();
std::function<std::string()> process = [&]() {
std::vector<std::vector<std::string>> alternatives(1);
std::vector<std::string> * sequence = &alternatives.back();
while (it != end) {
if (*it == '[') {
auto start = it;
++it;
while (it != end) {
if ((*it == '\\') && (++it != end)) {
++it;
} else if ((it != end) && (*it == ']')) {
break;
} else {
++it;
}
}
if (it == end) {
throw std::runtime_error("Unmatched '[' in pattern");
}
++it;
sequence->push_back(std::string(start, it));
} else if (*it == '*' || *it == '?' || *it == '+') {
if (sequence->empty()) {
throw std::runtime_error("Quantifier without preceding element");
}
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (*it == '?') {
++it;
}
}
} else if (*it == '{') {
if (sequence->empty()) {
throw std::runtime_error("Repetition without preceding element");
}
++it;
auto start = it;
while (it != end && *it != '}') {
++it;
}
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");
}
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
if (s.empty()) {
return def;
}
return std::stoi(s);
};
auto min = parseOptInt(parts[0], 0);
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
if (min && max && *max < *min) {
throw std::runtime_error("Invalid repetition range in pattern");
}
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
auto part = sequence->back();
sequence->pop_back();
for (int i = 0; i < *min; i++) {
sequence->push_back(part);
}
if (max) {
for (int i = *min; i < *max; i++) {
sequence->push_back(part + "?");
}
} else {
sequence->push_back(part + "*");
}
} else if (*it == '(') {
++it;
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
it += 2;
}
auto sub = process();
if (*it != ')') {
throw std::runtime_error("Unmatched '(' in pattern");
}
++it;
auto & part = sequence->emplace_back("(?:");
part += sub;
part += ")";
} else if (*it == ')') {
break;
} else if (*it == '|') {
++it;
alternatives.emplace_back();
sequence = &alternatives.back();
} else if (*it == '\\' && (++it != end)) {
auto str = std::string("\\") + *it;
sequence->push_back(str);
++it;
} else if (it != end) {
sequence->push_back(std::string(1, *it));
++it;
}
}
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts;
for (const auto & parts : alternatives) {
auto & res = res_alts.emplace_back();
for (size_t i = 0; i < parts.size() - 1; i++) {
res += "(?:";
}
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
res += *it;
if (it != parts.rend() - 1) {
res += ")?";
}
}
}
return string_join(res_alts, "|");
};
auto res = process();
if (it != end) {
throw std::runtime_error("Unmatched '(' in pattern");
}
return "(" + res + ")[\\s\\S]*";
}
+56
View File
@@ -0,0 +1,56 @@
#pragma once
#include <regex>
#include <string>
enum common_regex_match_type {
COMMON_REGEX_MATCH_TYPE_NONE,
COMMON_REGEX_MATCH_TYPE_PARTIAL,
COMMON_REGEX_MATCH_TYPE_FULL,
};
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
struct common_regex_match {
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
std::vector<common_string_range> groups;
bool operator==(const common_regex_match & other) const {
return type == other.type && groups == other.groups;
}
bool operator!=(const common_regex_match & other) const {
return !(*this == other);
}
};
class common_regex {
std::string pattern;
std::regex rx;
std::regex rx_reversed_partial;
public:
explicit common_regex(const std::string & pattern);
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
const std::string & str() const { return pattern; }
};
// For testing only (pretty print of failures).
std::string regex_to_reversed_partial_regex(const std::string & pattern);
+23 -2
View File
@@ -2069,6 +2069,9 @@ class Llama4Model(LlamaModel):
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
if name.startswith("language_model."):
name = name.replace("language_model.", "")
# split the gate_up into gate and up
if "gate_up_proj" in name:
name_up = name.replace("gate_up_proj", "up_proj.weight")
@@ -5746,11 +5749,20 @@ class GraniteModel(LlamaModel):
logger.info("gguf: (granite) logits_scale = %s", logits_scale)
@ModelBase.register("GraniteMoeForCausalLM")
@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM")
class GraniteMoeModel(GraniteModel):
"""Conversion for IBM's GraniteMoeForCausalLM"""
model_arch = gguf.MODEL_ARCH.GRANITE_MOE
def set_gguf_parameters(self):
"""GraniteMoeShared uses GraniteMoe parameters plus the following:
- shared_intermediate_size
"""
super().set_gguf_parameters()
if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
is used. This essentially merges w1 and w3 into a single tensor with 2x
@@ -5761,12 +5773,21 @@ class GraniteMoeModel(GraniteModel):
if name.endswith("block_sparse_moe.input_linear.weight"):
ffn_dim = self.hparams["intermediate_size"]
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
gate, up = data_torch.split(ffn_dim, dim=-2)
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
]
if name.endswith("shared_mlp.input_linear.weight"):
ffn_dim = self.hparams["shared_intermediate_size"]
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
gate, up = data_torch.split(ffn_dim, dim=-2)
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
]
return super().modify_tensors(data_torch, name, bid)
+2
View File
@@ -731,6 +731,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
@@ -741,6 +742,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
+1 -1
View File
@@ -31,7 +31,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
## Pre-quantized models
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default.
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
+1
View File
@@ -32,6 +32,7 @@ else()
add_subdirectory(speculative)
add_subdirectory(speculative-simple)
add_subdirectory(gen-docs)
add_subdirectory(training)
if (NOT GGML_BACKEND_DL)
add_subdirectory(convert-llama2c-to-ggml)
# these examples use the backends directly and cannot be built with dynamic loading
+5
View File
@@ -0,0 +1,5 @@
set(TARGET llama-finetune)
add_executable(${TARGET} finetune.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
+17
View File
@@ -0,0 +1,17 @@
# llama.cpp/examples/training
This directory contains examples related to language model training using llama.cpp/GGML.
So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP.
Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
**For CPU training, compile llama.cpp without any additional backends such as CUDA.**
**For CUDA training, use the maximum number of GPU layers.**
Proof of concept:
``` sh
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
```
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
+96
View File
@@ -0,0 +1,96 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int main(int argc, char ** argv) {
common_params params;
params.escape = false;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
return 1;
}
if (params.use_mmap) {
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
params.use_mmap = false;
}
if (params.cache_type_k != GGML_TYPE_F32) {
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_k = GGML_TYPE_F32;
}
if (params.cache_type_v != GGML_TYPE_F32) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
// load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
return 1;
}
// print system information
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}
constexpr float val_split = 0.05f;
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
optimizer_params.adamw.alpha = 1e-7f; // learning rate
struct llama_opt_params lopt_params {
/*n_ctx_train =*/ 0,
/*param_filter =*/ llama_opt_param_filter_all,
/*param_filter_ud =*/ nullptr,
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
/*get_opt_pars_ud =*/ &optimizer_params,
};
llama_opt_init(ctx.get(), model.get(), lopt_params);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
ggml_opt_result_t result_train = ggml_opt_result_init();
ggml_opt_result_t result_eval = ggml_opt_result_init();
for (int epoch = 0; epoch < 2; ++epoch) {
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n");
ggml_opt_result_reset(result_train);
ggml_opt_result_reset(result_eval);
}
ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval);
llama_model_save_to_file(model.get(), "finetuned-model.gguf");
llama_backend_free();
return 0;
}
+1
View File
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
option(GGML_SYCL "ggml: use SYCL" OFF)
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
+2 -2
View File
@@ -248,7 +248,7 @@ extern "C" {
// preferrably to run on the same backend as the buffer
ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
// initialize buffers from a max size graph (optional)
reserve_graph = build_graph(sched, max_batch_size);
@@ -289,7 +289,7 @@ extern "C" {
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
// Initialize a backend scheduler, backends with low index are given priority over backends with high index
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph
+47 -28
View File
@@ -37,13 +37,16 @@ extern "C" {
// ====== Dataset ======
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
enum ggml_type type_data, // the type for the internal data tensor
enum ggml_type type_label, // the type for the internal labels tensor
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
// get underlying tensors that store the data
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
@@ -56,13 +59,19 @@ extern "C" {
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
int64_t ibatch);
GGML_API void ggml_opt_dataset_get_batch_host(
ggml_opt_dataset_t dataset,
void * data_batch,
size_t nb_data_batch,
void * labels_batch,
int64_t ibatch);
// ====== Model / Context ======
enum ggml_opt_build_type {
GGML_OPT_BUILD_TYPE_FORWARD,
GGML_OPT_BUILD_TYPE_GRAD,
GGML_OPT_BUILD_TYPE_OPT,
GGML_OPT_BUILD_TYPE_FORWARD = 10,
GGML_OPT_BUILD_TYPE_GRAD = 20,
GGML_OPT_BUILD_TYPE_OPT = 30,
};
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
// userdata can be used to pass arbitrary data
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
// returns the default optimizer params (constant)
// returns the default optimizer params (constant, hard-coded values)
// userdata is not used
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
// casts userdata to ggml_opt_optimizer_params and returns it
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
// parameters for initializing a new optimization context
struct ggml_opt_params {
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
// the forward graph is defined by inputs and outputs
// those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
// by default the forward graph needs to be reconstructed for each eval
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
struct ggml_context * ctx_compute;
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
enum ggml_opt_loss_type loss_type;
enum ggml_opt_build_type build_type;
@@ -107,12 +118,9 @@ extern "C" {
// get parameters for an optimization context with defaults set where possible
// parameters for which no sensible defaults exist are supplied as arguments to this function
GGML_API ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
struct ggml_context * ctx_compute,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs,
enum ggml_opt_loss_type loss_type);
GGML_API struct ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
enum ggml_opt_loss_type loss_type);
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@@ -121,6 +129,7 @@ extern "C" {
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
// get underlying tensors that store data
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
@@ -128,11 +137,12 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
// get the gradient accumulator for a node from the forward graph
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
// ====== Optimization Result ======
GGML_API ggml_opt_result_t ggml_opt_result_init();
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
@@ -144,11 +154,20 @@ extern "C" {
// ====== Computation ======
// do forward pass, increment result if not NULL
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// if not using static graphs, this function must be called prior to ggml_opt_alloc
GGML_API void ggml_opt_prepare_alloc(
ggml_opt_context_t opt_ctx,
struct ggml_context * ctx_compute,
struct ggml_cgraph * gf,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs);
// do forward pass, increment result if not NULL, do backward pass
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// allocate the next graph for evaluation, either forward or forward + backward
// must be called exactly once prior to calling ggml_opt_eval
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
// do forward pass, increment result if not NULL, do backward pass if allocated
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// ############################################################################
// ## The high-level functions start here. They do not depend on any private ##
@@ -200,9 +219,9 @@ extern "C" {
// fit model defined by inputs and outputs to dataset
GGML_API void ggml_opt_fit(
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
+6 -7
View File
@@ -768,7 +768,7 @@ extern "C" {
// Tensor flags
GGML_API void ggml_set_input(struct ggml_tensor * tensor);
GGML_API void ggml_set_output(struct ggml_tensor * tensor);
GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API void ggml_set_param(struct ggml_tensor * tensor);
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
//
@@ -938,7 +938,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
// concat a and b along dim
// used in stable-diffusion
@@ -2049,15 +2049,14 @@ extern "C" {
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(
struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
struct ggml_context * ctx_compute, // context for gradient computation
struct ggml_cgraph * cgraph,
bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
struct ggml_context * ctx, // context for gradient computation
struct ggml_cgraph * cgraph,
struct ggml_tensor ** grad_accs);
// graph allocation in a context
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
+7 -3
View File
@@ -674,6 +674,8 @@ struct ggml_backend_sched {
char * context_buffer;
size_t context_buffer_size;
bool op_offload;
int debug;
};
@@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
// check if a backend with higher prio wants to offload the op
if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
for (int b = 0; b < src_backend_id; b++) {
if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
SET_CAUSE(tensor, "1.off");
@@ -1109,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
const int node_backend_id = tensor_backend_id(node);
assert(node_backend_id != -1); // all nodes should be assigned by now
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
@@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
ggml_backend_buffer_type_t * bufts,
int n_backends,
size_t graph_size,
bool parallel) {
bool parallel,
bool op_offload) {
GGML_ASSERT(n_backends > 0);
GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1497,6 +1500,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
}
sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
sched->op_offload = op_offload;
ggml_backend_sched_reset(sched);
+20 -13
View File
@@ -385,9 +385,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources:
include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.5.0")
set(KLEIDIAI_COMMIT_TAG "v1.6.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
@@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
if (NOT DOTPROD_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
endif()
if (NOT I8MM_ENABLED MATCHES -1)
@@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
endif()
if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
endif()
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
+195
View File
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q6_K * GGML_RESTRICT x0 = x;
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
const block_q8_K * GGML_RESTRICT y0 = y;
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
float32x4_t vfsum = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
const int8_t * GGML_RESTRICT qy0 = y0->qs;
const int8_t * GGML_RESTRICT qy1 = y1->qs;
const uint8x16_t mone = vdupq_n_u8(0x30);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
int32x4_t visum = vdupq_n_s32(0);
// process 8 blocks per iteration, totally 16 blocks
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
int8x16_t vx0[8], vx1[8];
// de-quantize vx0[8]
{
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
}
// de-quantize vx1[8]
{
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
}
// process 16 elements (one block with same scale) per iteration
// - vx = concat(ql, qh) - 32
// - r1,r2,r3,r4 = smmla(vx, vy)
for (int k = 0; k < 8; ++k) {
const int blk = j * 8 + k;
const int8x16_t vy0 = vld1q_s8(qy0);
const int8x16_t vy1 = vld1q_s8(qy1);
qy0 += 16;
qy1 += 16;
const int32x4_t block_scale = {
x0->scales[blk],
x0->scales[blk],
x1->scales[blk],
x1->scales[blk],
};
// calculate four results at once with outer product
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
int32x4_t vr = vdupq_n_s32(0);
vr = vmmlaq_s32(vr, vx_l, vy_l);
vr = vmmlaq_s32(vr, vx_h, vy_h);
// apply block scale, will NOT overflow
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
visum = vmlaq_s32(visum, vr, block_scale);
}
}
// adjust bias, apply superblock scale
{
int32_t bias[4];
#ifdef __ARM_FEATURE_SVE
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
const svint64_t zero = svdup_n_s64(0);
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
#else
// NEON doesn't support int16 dot product, fallback to separated mul and add
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
int8x16_t scales_s8 = vld1q_s8(x0->scales);
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
scales_s8 = vld1q_s8(x1->scales);
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
int32x4_t prod;
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
bias[0] = vaddvq_s32(prod);
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
bias[1] = vaddvq_s32(prod);
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[2] = vaddvq_s32(prod);
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[3] = vaddvq_s32(prod);
#endif
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
const float32x4_t superblock_scale = {
GGML_FP16_TO_FP32(x0->d) * y0->d,
GGML_FP16_TO_FP32(x0->d) * y1->d,
GGML_FP16_TO_FP32(x1->d) * y0->d,
GGML_FP16_TO_FP32(x1->d) * y1->d,
};
visum = vsubq_s32(visum, vibias);
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
}
}
// vfsum = ABCD -> ACBD
// AC -> s, BD -> (s+bs)
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
vst1_f32(s, vget_low_f32 (vfsum));
vst1_f32(s + bs, vget_high_f32(vfsum));
return;
}
#endif
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
+4
View File
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_K,
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
},
[GGML_TYPE_IQ2_XXS] = {
.from_float = NULL,
+88 -5
View File
@@ -4,16 +4,22 @@
// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_common.h"
#include "kernels.h"
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
{
/* SME GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_F16,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__APPLE__)
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#else
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_DOTPROD)
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#endif
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
ggml_kleidiai_kernels * kernel = nullptr;
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels[i];
break;
}
}
}
return kernel;
}
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
+47 -12
View File
@@ -4,6 +4,10 @@
#pragma once
#include <functional>
#include <variant>
#include "ggml.h"
enum cpu_feature {
CPU_FEATURE_NONE = 0,
CPU_FEATURE_DOTPROD = 1,
@@ -26,26 +30,53 @@ struct kernel_info {
size_t (*get_nr)(void);
size_t (*get_kr)(void);
size_t (*get_sr)(void);
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t m_idx, size_t k)>
> get_lhs_offset;
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t n_idx, size_t k)>
> get_rhs_packed_offset;
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
size_t (*get_dst_size)(size_t m, size_t n);
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
std::variant<
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
size_t dst_stride_col, float clamp_min, float clamp_max)>
> run_kernel;
};
struct lhs_packing_info {
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
size_t lhs_stride, void* lhs_packed);
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
> get_packed_offset;
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
> packed_size;
std::variant<
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
size_t lhs_stride, void* lhs_packed)>,
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
void* lhs_packed)>
> pack_func;
};
struct rhs_packing_info {
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
std::variant<
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
std::function<size_t(size_t n, size_t k)>
> packed_size;
std::variant<
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
> pack_func;
};
struct ggml_kleidiai_kernels {
@@ -55,6 +86,10 @@ struct ggml_kleidiai_kernels {
rhs_packing_info rhs_info;
cpu_feature required_cpu;
ggml_type lhs_type;
ggml_type rhs_type;
ggml_type op_type;
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
+266 -71
View File
@@ -3,7 +3,9 @@
//
#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#if defined(__linux__)
@@ -34,8 +36,9 @@
#include "ggml-common.h"
struct ggml_kleidiai_context {
cpu_feature features;
ggml_kleidiai_kernels * kernels;
} static ctx = { NULL };
} static ctx = { CPU_FEATURE_NONE, NULL };
static void init_kleidiai_context(void) {
@@ -47,18 +50,18 @@ static void init_kleidiai_context(void) {
const char *env_var = getenv("GGML_KLEIDIAI_SME");
int sme_enabled = 0;
cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
if (env_var) {
sme_enabled = atoi(env_var);
}
if (sme_enabled != 0) {
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels(features);
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
}
ggml_critical_section_end();
}
@@ -68,95 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
return tensor->ne[dim];
}
template<typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args&&... args) {
return std::visit([&](auto&& func) -> Ret {
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
return func(std::forward<Args>(args)...);
} else {
throw std::runtime_error("Invalid function type in variant_call");
}
}, var);
}
namespace ggml::cpu::kleidiai {
static size_t round_down(size_t x, size_t y) {
return y == 0 ? x : x - (x % y);
}
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
size_t src_stride = rhs_stride / sizeof(uint16_t);
size_t dst_stride = n;
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
uint16_t v = *(src + k_idx + n_idx * src_stride);
*(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
}
}
}
class tensor_traits : public ggml::cpu::tensor_traits {
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
GGML_ASSERT(ctx.kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
size_t k = op->src[0]->ne[0];
size_t n = op->src[0]->ne[1];
size_t m = op->src[1]->ne[1];
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float);
} else {
GGML_ASSERT(false);
}
return true;
}
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
if (dst->op == GGML_OP_MUL_MAT) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_kv_cache(params, dst);
}
}
return false;
}
GGML_TENSOR_BINARY_OP_LOCALS
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
GGML_ASSERT(ctx.kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(kernel);
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
GGML_ASSERT(kernel);
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
const int nth = params->nth;
const int ith = params->ith;
size_t n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
const int64_t lhs_batch_size0 = ne12;
const int64_t rhs_batch_size0 = ne02;
const int64_t batch_size = rhs_batch_size0;
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
const int64_t m = ne11 * r;
const int64_t n = ne01;
const int64_t k = ne00;
const size_t lhs_stride = src1->nb[1];
const size_t rhs_stride = src0->nb[1];
const size_t dst_stride = dst->nb[1];
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
const size_t kxn_size = k * n * sizeof(float);
const size_t bias_size = n * sizeof(float);
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
GGML_ASSERT(wsize_required <= params->wsize);
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
uint8_t * bias = rhs_kxn + kxn_size;
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
// LHS packing
{
const int64_t m_roundup_mr = kai_roundup(m, mr);
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
if (ith < num_threads) {
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
const int64_t m_start = ith * num_m_per_thread0;
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
}
}
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = (uint8_t*)params->wdata;
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
// RHS packing
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
// First thread to reach this point handles RHS packing
memset(bias, 0, n * sizeof(float));
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
// Calculate number of columns to be processed per thread
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if(m_start < m) {
// Transform LHS
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
}
ggml_barrier(params->threadpool);
// Perform the operation
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
first_to_arrive.clear(std::memory_order_release);
kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
return true;
// Perform the matmul
{
const int64_t m_to_process = m;
const int64_t m_start = 0;
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
const int64_t num_threads = KAI_MIN(n / n_step, nth);
if (ith < num_threads) {
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
const int64_t n_start = ith * num_n_per_thread0;
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
}
}
if (batch_idx != batch_size - 1) {
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
// the work data buffer (params->wdata) is used as temporary storage which means that only
// a single batch can be processed at any given time. No barrier is needed for the last
// batch since GGML inserts a barrier between the execution of every operator.
ggml_barrier(params->threadpool);
}
}
return false;
return true;
}
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = &kernels->lhs_info;
GGML_ASSERT(kernel);
const int ith = params->ith;
const int nth = params->nth;
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = (uint8_t*)params->wdata;
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
// Calculate number of columns to be processed per thread
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if (m_start < m) {
// Transform LHS
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
// Perform the operation
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
return true;
}
public:
@@ -169,13 +352,13 @@ public:
size_t sr = ctx.kernels->gemm.get_sr();
#ifndef NDEBUG
const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
#endif
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
return 0;
@@ -189,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
}
} // namespace ggml::cpu::kleidiai
GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
GGML_UNUSED(buffer);
@@ -238,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ( op->op == GGML_OP_MUL_MAT &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
) {
if (op->op == GGML_OP_MUL_MAT &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
@@ -260,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
op->src[0]->op == GGML_OP_VIEW &&
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
op->src[1]->ne[1] > 1) {
if ((op->src[0]->nb[0] != 2) ||
(op->src[1]->nb[0] != 4) ||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
return nullptr;
}
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
}
}
return nullptr;
}
+40 -26
View File
@@ -1,47 +1,61 @@
#include "acc.cuh"
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset) {
const int i = blockDim.x * blockIdx.x + threadIdx.x;
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
int src1_idx = i - offset;
int oz = src1_idx / nb2;
int oy = (src1_idx - (oz * nb2)) / nb1;
int ox = src1_idx % nb1;
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
} else {
dst[i] = x[i];
int64_t src1_idx = i - offset;
int64_t tmp = src1_idx;
const int64_t i13 = tmp / s13;
tmp -= i13 * s13;
const int64_t i12 = tmp / s12;
tmp -= i12 * s12;
const int64_t i11 = tmp / s11;
tmp -= i11 * s11;
const int64_t i10 = tmp;
float val = x[i];
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
}
dst[i] = val;
}
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, const int offset, cudaStream_t stream) {
int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
}
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
const int64_t s1 = dst->op_params[0] / sizeof(float);
const int64_t s2 = dst->op_params[1] / sizeof(float);
const int64_t s3 = dst->op_params[2] / sizeof(float);
const int64_t offset = dst->op_params[3] / sizeof(float);
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
}
+13 -5
View File
@@ -678,10 +678,14 @@ void launch_fattn(
) {
constexpr int ncols = ncols1 * ncols2;
const bool is_mla = DV == 512; // TODO better parameterization
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(V || is_mla);
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
@@ -689,6 +693,10 @@ void launch_fattn(
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
size_t nb12 = K->nb[2];
size_t nb13 = K->nb[3];
const char * V_data = (const char *) V->data;
size_t nb21 = V->nb[1];
size_t nb22 = V->nb[2];
size_t nb23 = V->nb[3];
const char * V_data = V ? (const char *) V->data : nullptr;
size_t nb21 = V ? V->nb[1] : nb11;
size_t nb22 = V ? V->nb[2] : nb12;
size_t nb23 = V ? V->nb[3] : nb13;
if (need_f16_K && K->type != GGML_TYPE_F16) {
GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
nb13 = nb13*bs*sizeof(half)/ts;
}
if (need_f16_V && V->type != GGML_TYPE_F16) {
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
GGML_ASSERT(ggml_is_contiguously_allocated(V));
V_f16.alloc(ggml_nelements(V));
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+266 -64
View File
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 32;
static constexpr int nbatch_V2 = 32;
static constexpr int nbatch_combine = 32;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
};
template <>
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 40;
static constexpr int nbatch_V2 = 40;
static constexpr int nbatch_combine = 40;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 40;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 40;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 40;
}
};
template <>
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 48;
static constexpr int nbatch_V2 = 48;
static constexpr int nbatch_combine = 48;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 48;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 48;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 48;
}
};
template <>
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 56;
static constexpr int nbatch_V2 = 56;
static constexpr int nbatch_combine = 56;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 56;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 56;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 56;
}
};
template <>
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 64;
static constexpr int nbatch_V2 = 64;
static constexpr int nbatch_combine = 64;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 64;
}
};
template <>
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 128;
static constexpr int nbatch_V2 = 128;
static constexpr int nbatch_combine = 128;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_combine_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 128 : 64;
}
return 64;
}
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 128 : 64;
#else
GGML_UNUSED(ncols);
return 128;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
};
template <>
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
static constexpr int nwarps_max = 8;
static constexpr bool Q_in_reg = false;
static constexpr int nstages_target = 1;
static constexpr int nbatch_K2 = 160;
static constexpr int nbatch_V2 = 128;
static constexpr int nbatch_combine = 128;
static int get_nbatch_K2_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 96 : 160;
}
return ncols <= 16 ? 288 : 160;
}
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 96 : 160;
#else
return ncols <= 16 ? 288 : 160;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
static int get_nbatch_V2_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 64 : 128;
}
return ncols <= 16 ? 256 : 128;
}
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 64 : 128;
#else
return ncols <= 16 ? 256 : 128;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 128;
}
};
// ------------------------------------------------------------------------------------------------------------------
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
auto load = [&] __device__ (const int n) {
auto load = [&] __device__ (auto n) {
const int stride_k = WARP_SIZE >> n;
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int ncols = ncols1 * ncols2;
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = c::nbatch_K2 + 4;
constexpr int stride_tile_V = c::nbatch_V2 + 4;
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
const int k_VKQ_0 = kb0 * c::nbatch_fa;
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
if constexpr (nstages > 1) {
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
static_assert(!mla, "multi-stage loading not implemented for MLA");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true;
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
#pragma unroll
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
if (nstages <= 1) {
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}
#pragma unroll
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
const int i0_diff = i0_stop - i0_start;
if (nstages <= 1) {
// For MLA K and V have the same data.
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
#pragma unroll
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
const int i0_diff = i0_stop - i0_start;
if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
__syncthreads();
}
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
// Calculate VKQ tile:
#pragma unroll
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A;
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
if (ntiles == 1) {
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
} else {
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#endif // NEW_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = c::nbatch_K2 + 4;
constexpr int stride_tile_V = c::nbatch_V2 + 4;
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[];
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Preload mask and K data for first iteration when using cp_async with multiple stages:
if constexpr (nstages > 1) {
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
constexpr bool use_cp_async = true;
if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
constexpr int tile_stride = nbatch_combine + 4;
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
@@ -895,6 +1079,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
} else if (np > 1) {
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
// Therefore, all other warps also need to execute a __syncthreads().
// Otherwise the points at which warps synchronize with each other would become misaligned.
__syncthreads();
}
#pragma unroll
@@ -1007,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#endif // NEW_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
__launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
@@ -1052,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
return;
}
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1062,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q1 = nb01 / sizeof(float2);
const int stride_Q2 = nb02 / sizeof(float2);
const int stride_K = nb11 / sizeof(half2);
const int stride_V = nb21 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half2);
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
@@ -1087,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1099,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1125,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1136,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
@@ -1162,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
typedef fattn_mma_f16_config<DKQ, DV> c;
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2;
@@ -1175,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
constexpr bool mla = DKQ == 576;
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
static_assert(DV % tile_A::J == 0, "bad DV");
static_assert(ncols % cols_per_warp == 0, "bad ncols");
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
@@ -1197,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1208,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+2 -1
View File
@@ -10,6 +10,7 @@
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * Q = dst->src[0];
if constexpr (ncols2 <= 8) {
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
return;
}
if (Q->ne[1] <= 32/ncols2) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
+9 -3
View File
@@ -1909,13 +1909,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
// If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
// But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
// Therefore, in such cases use cuBLAS.
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_is_quantized(src0->type)
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool any_gpus_with_slow_fp16 = false;
@@ -3216,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
if (!new_mma_available(cc)) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+4 -2
View File
@@ -91,11 +91,11 @@ void ggml_cuda_mul_mat_q(
// If src0 is a temporary compute buffer, clear any potential padding.
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
GGML_ASSERT(!src0->view_src);
const size_t size_data = ggml_nbytes(src0);
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
if (size_alloc > size_data) {
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
GGML_ASSERT(!src0->view_src);
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
}
}
@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
@@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+2 -2
View File
@@ -515,11 +515,11 @@ void ggml_cuda_mul_mat_vec_q(
// If src0 is a temporary compute buffer, clear any potential padding.
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
GGML_ASSERT(!src0->view_src);
const size_t size_data = ggml_nbytes(src0);
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
if (size_alloc > size_data) {
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
GGML_ASSERT(!src0->view_src);
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
}
}
+7 -6
View File
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
if (i0 >= ne0) {
return;
}
const int64_t i1 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
// Load 4 floats per thread and calculate max. abs. value between them:
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
+1 -1
View File
@@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
+4
View File
@@ -207,6 +207,10 @@ typedef struct {
float attn_factor;
float beta_fast;
float beta_slow;
int32_t sect_0;
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {
+43 -17
View File
@@ -332,6 +332,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return true;
}
return true;
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
@@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node(
} break;
case GGML_OP_ROPE:
{
// make sure we have one or more position id(ne10) per token(ne02)
GGML_ASSERT(ne10 % ne02 == 0);
GGML_ASSERT(ne10 >= ne02);
@@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node(
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
// mrope
const int sect_0 = ((const int32_t *) dst->op_params)[11];
const int sect_1 = ((const int32_t *) dst->op_params)[12];
const int sect_2 = ((const int32_t *) dst->op_params)[13];
const int sect_3 = ((const int32_t *) dst->op_params)[14];
id<MTLComputePipelineState> pipeline = nil;
if (!is_neox) {
if (is_neox) {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else if (is_mrope && !is_vision) {
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else if (is_vision) {
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
}
@@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node(
/*.attn_factor =*/ attn_factor,
/*.beta_fast =*/ beta_fast,
/*.beta_slow =*/ beta_slow,
/* sect_0 =*/ sect_0,
/* sect_1 =*/ sect_1,
/* sect_2 =*/ sect_2,
/* sect_3 =*/ sect_3,
};
[encoder setComputePipelineState:pipeline];
@@ -4332,7 +4358,7 @@ static bool ggml_metal_encode_node(
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
// for now avoiding mainly to keep the number of templates/kernels a bit lower
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
switch (src1->type) {
case GGML_TYPE_F16:
{
+151
View File
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
}
}
template<typename T>
kernel void kernel_rope_multi(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
// mrope theta calculations
// note: the rest is the same as kernel_rope_neox
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
// end of mrope
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_vision(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
const int ic = i0/2;
// mrope theta calculations (only support 2 dimensions)
const int sect_dims = args.sect_0 + args.sect_1;
const int sector = ic % sect_dims;
float p;
float theta_base;
if (sector < args.sect_1) {
p = (float) sector;
theta_base = (float) pos[i2];
} else {
p = (float) sector - args.sect_0;
theta_base = (float) pos[i2 + args.ne02];
}
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
typedef void (im2col_t)(
device const float * x,
device char * dst,
@@ -3741,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec(
sm[tiisg] = pm[ic + tiisg];
}
// skip -INF blocks
if (simd_max(sm[tiisg]) == -INFINITY) {
continue;
}
// Q*K^T
{
// each simdgroup processes 1 query and NE (NW/NL) head elements
-2
View File
@@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
if (!any_on_device) {
return false;
}
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
func = ggml_cl_add;
break;
case GGML_OP_MUL:
+374 -196
View File
@@ -28,16 +28,19 @@ struct ggml_opt_dataset {
};
struct ggml_opt_context {
ggml_backend_sched_t backend_sched = nullptr;
ggml_cgraph * allocated_graph = nullptr;
ggml_cgraph * allocated_graph_copy = nullptr;
struct ggml_context * ctx_static = nullptr;
struct ggml_context * ctx_static_cpu = nullptr;
struct ggml_context * ctx_compute = nullptr;
struct ggml_context * ctx_copy = nullptr;
ggml_backend_buffer_t buf_static = nullptr;
ggml_backend_buffer_t buf_static_cpu = nullptr;
std::mt19937 rng;
ggml_backend_sched_t backend_sched = nullptr;
ggml_cgraph * allocated_graph = nullptr;
ggml_cgraph * allocated_graph_copy = nullptr;
struct ggml_context * ctx_static = nullptr;
struct ggml_context * ctx_cpu = nullptr;
struct ggml_context * ctx_compute = nullptr;
struct ggml_context * ctx_copy = nullptr;
ggml_backend_buffer_t buf_static = nullptr;
ggml_backend_buffer_t buf_cpu = nullptr;
std::mt19937 rng;
enum ggml_opt_loss_type loss_type;
enum ggml_opt_build_type build_type;
enum ggml_opt_build_type build_type_alloc;
struct ggml_tensor * inputs = nullptr;
struct ggml_tensor * outputs = nullptr;
@@ -50,6 +53,11 @@ struct ggml_opt_context {
struct ggml_cgraph * gf = nullptr;
struct ggml_cgraph * gb_grad = nullptr;
struct ggml_cgraph * gb_opt = nullptr;
bool static_graphs = false;
bool eval_ready = false;
std::vector<struct ggml_tensor *> grad_accs;
std::vector<struct ggml_tensor *> grad_m;
std::vector<struct ggml_tensor *> grad_v;
int64_t iter = 1;
int32_t opt_period = 1;
@@ -73,7 +81,13 @@ struct ggml_opt_result {
// ====== Dataset ======
ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
ggml_opt_dataset_t ggml_opt_dataset_init(
enum ggml_type type_data,
enum ggml_type type_label,
int64_t ne_datapoint,
int64_t ne_label,
int64_t ndata,
int64_t ndata_shard) {
GGML_ASSERT(ne_datapoint > 0);
GGML_ASSERT(ne_label >= 0);
GGML_ASSERT(ndata > 0);
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
result->ctx = ggml_init(params);
}
result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
if (ne_label > 0) {
result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
} else {
result->labels = nullptr;
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
delete dataset;
}
int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
return dataset->ndata;
}
struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
return dataset->data;
}
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
GGML_ASSERT( data_batch->type == dataset->data->type);
GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
const size_t nb_data_batch = ggml_nbytes(data_batch);
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
}
}
void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
if (!labels_batch) {
continue;
}
const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
}
}
// ====== Model / Context ======
struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
return result;
}
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
return *((struct ggml_opt_optimizer_params *) userdata);
}
struct ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
struct ggml_context * ctx_compute,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs,
enum ggml_opt_loss_type loss_type) {
return {
/*backend_sched =*/ backend_sched,
/*ctx_compute =*/ ctx_compute,
/*inputs =*/ inputs,
/*logits =*/ outputs,
/*ctx_compute =*/ nullptr,
/*inputs =*/ nullptr,
/*logits =*/ nullptr,
/*loss_type =*/ loss_type,
/*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
/*opt_period =*/ 1,
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
return dst;
}
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
GGML_ASSERT(graph);
if (opt_ctx->allocated_graph == graph) {
return;
}
static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ggml_free(opt_ctx->ctx_copy);
opt_ctx->ctx_copy = ggml_init(params);
}
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
opt_ctx->allocated_graph = graph;
}
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
ggml_opt_context_t result = new struct ggml_opt_context;
result->backend_sched = params.backend_sched;
result->ctx_compute = params.ctx_compute;
result->inputs = params.inputs;
result->outputs = params.outputs;
result->opt_period = params.opt_period;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
GGML_ASSERT(result->opt_period >= 1);
const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
(params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
ggml_set_input(result->inputs);
ggml_set_output(result->outputs);
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
ggml_build_forward_expand(result->gf, result->outputs);
ggml_set_input(opt_ctx->inputs);
ggml_set_output(opt_ctx->outputs);
int n_param = 0;
for (int i = 0; i < result->gf->n_nodes; ++i) {
if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
n_param++;
}
GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
}
{
if (!opt_ctx->ctx_static) {
// The static context is used for:
// - gradients (1 tensor per param if using gradient accumulation)
// - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
// - optimizer momenta (2 tensors per param)
// - labels
// - loss + its gradient (up to 5 tensors)
// - pred
// - ncorrect (2 tensors).
const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
// - labels (if using static graphs)
// - loss (if using static graphs, up to 5 tensors)
// - pred (if using static graphs)
// - ncorrect (if using static graphs, 2 tensors).
constexpr size_t n_loss = 1;
const size_t tensors_per_param = (accumulate ? 1 : 0) +
(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
result->ctx_static = ggml_init(params);
opt_ctx->ctx_static = ggml_init(params);
}
GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
{
// The static cpu context is used for:
// - optimizer parameters (1 for the entire context)
// The cpu context is allocated statically if using static graphs, dynamically otherwise.
// It is used for:
// - optimizer parameters (1 shared for all optimizer invocations)
const size_t size_meta = 1 * ggml_tensor_overhead();
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
result->ctx_static_cpu = ggml_init(params);
ggml_free(opt_ctx->ctx_cpu);
opt_ctx->ctx_cpu = ggml_init(params);
ggml_backend_buffer_free(opt_ctx->buf_cpu);
opt_ctx->buf_cpu = nullptr;
}
struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
switch (params.loss_type) {
switch (opt_ctx->loss_type) {
case GGML_OPT_LOSS_TYPE_MEAN: {
result->loss = ggml_sum(result->ctx_static, result->outputs);
ggml_set_name(result->loss, "loss_sum");
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
ggml_set_name(result->loss, "loss_mean");
result->loss_per_datapoint = true;
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
ggml_set_name(opt_ctx->loss, "loss_sum");
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
ggml_set_name(opt_ctx->loss, "loss_mean");
opt_ctx->loss_per_datapoint = true;
break;
}
case GGML_OPT_LOSS_TYPE_SUM: {
result->loss = ggml_sum(result->ctx_static, result->outputs);
ggml_set_name(result->loss, "loss_sum");
result->loss_per_datapoint = false;
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
ggml_set_name(opt_ctx->loss, "loss_sum");
opt_ctx->loss_per_datapoint = false;
break;
}
case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
ggml_set_input(result->labels);
ggml_set_name(result->labels, "labels");
result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
ggml_set_name(result->loss, "loss_cross_entropy");
if (result->opt_period > 1) {
result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
ggml_set_name(result->loss, "loss_cross_entropy_scaled");
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
ggml_set_input(opt_ctx->labels);
ggml_set_name(opt_ctx->labels, "labels");
opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
if (opt_ctx->opt_period > 1) {
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
}
result->loss_per_datapoint = true;
opt_ctx->loss_per_datapoint = true;
break;
}
case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
ggml_set_input(result->labels);
ggml_set_name(result->labels, "labels");
result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
ggml_set_name(result->loss, "loss_error");
result->loss = ggml_sqr(result->ctx_static, result->loss);
ggml_set_name(result->loss, "loss_squared_error");
result->loss = ggml_sum(result->ctx_static, result->loss);
ggml_set_name(result->loss, "loss_sum_squared_error");
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
ggml_set_name(result->loss, "loss_mean_squared_error");
result->loss_per_datapoint = true;
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
ggml_set_input(opt_ctx->labels);
ggml_set_name(opt_ctx->labels, "labels");
opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
ggml_set_name(opt_ctx->loss, "loss_error");
opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
ggml_set_name(opt_ctx->loss, "loss_squared_error");
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
opt_ctx->loss_per_datapoint = true;
break;
}
}
ggml_set_output(result->loss);
ggml_set_loss(result->loss);
ggml_build_forward_expand(result->gf, result->loss);
ggml_set_output(opt_ctx->loss);
ggml_set_loss(opt_ctx->loss);
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
result->pred = ggml_argmax(result->ctx_static, result->outputs);
ggml_set_name(result->pred, "pred");
ggml_set_output(result->pred);
ggml_build_forward_expand(result->gf, result->pred);
if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
ggml_set_name(opt_ctx->pred, "pred");
ggml_set_output(opt_ctx->pred);
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
if (result->labels) {
result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
ggml_set_name(result->ncorrect, "ncorrect");
ggml_set_output(result->ncorrect);
ggml_build_forward_expand(result->gf, result->ncorrect);
} else {
result->ncorrect = nullptr;
opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
ggml_set_name(opt_ctx->ncorrect, "ncorrect");
ggml_set_output(opt_ctx->ncorrect);
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
}
if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
return result;
if (opt_ctx->buf_static) {
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
return;
}
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
return;
}
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
if (opt_ctx->grad_accs.empty()) {
GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
ggml_graph_reset(result->gb_grad);
return result;
}
const int n_nodes = opt_ctx->gf->n_nodes;
opt_ctx->grad_accs.resize(n_nodes);
for (int i = 0; i < n_nodes; ++i) {
ggml_tensor * node = opt_ctx->gf->nodes[i];
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
} else {
opt_ctx->grad_accs[i] = nullptr;
}
}
GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
ggml_set_input(result->adamw_params);
ggml_set_name(result->adamw_params, "adamw_params");
for (int i = result->gf->n_nodes-1; i >= 0; --i) {
struct ggml_tensor * node = result->gb_opt->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
ggml_build_forward_expand(result->gb_opt, opt_step);
if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
opt_ctx->grad_m.resize(n_nodes);
opt_ctx->grad_v.resize(n_nodes);
for (int i = 0; i < n_nodes; ++i) {
ggml_tensor * node = opt_ctx->gf->nodes[i];
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
} else {
opt_ctx->grad_m[i] = nullptr;
opt_ctx->grad_v[i] = nullptr;
}
}
}
}
result->buf_static = ggml_backend_alloc_ctx_tensors(
result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
if (opt_ctx->buf_static) {
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
return;
}
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
ggml_graph_reset(opt_ctx->gb_grad);
}
ggml_graph_reset(result->gb_opt);
GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
ggml_set_input(opt_ctx->adamw_params);
ggml_set_name(opt_ctx->adamw_params, "adamw_params");
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
struct ggml_tensor * m = opt_ctx->grad_m[i];
struct ggml_tensor * v = opt_ctx->grad_v[i];
struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
}
}
if (!opt_ctx->buf_static) {
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
ggml_graph_reset(opt_ctx->gb_opt);
}
opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
}
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
ggml_opt_context_t result = new struct ggml_opt_context;
result->backend_sched = params.backend_sched;
result->ctx_compute = params.ctx_compute;
result->loss_type = params.loss_type;
result->build_type = params.build_type;
result->build_type_alloc = params.build_type;
result->inputs = params.inputs;
result->outputs = params.outputs;
result->opt_period = params.opt_period;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
GGML_ASSERT(result->opt_period >= 1);
result->static_graphs = result->ctx_compute;
if (!result->static_graphs) {
GGML_ASSERT(!result->inputs);
GGML_ASSERT(!result->outputs);
return result;
}
GGML_ASSERT(result->inputs);
GGML_ASSERT(result->outputs);
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
ggml_build_forward_expand(result->gf, result->outputs);
ggml_opt_build(result);
return result;
}
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
return;
}
ggml_backend_buffer_free(opt_ctx->buf_static);
ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
ggml_backend_buffer_free(opt_ctx->buf_cpu);
ggml_free(opt_ctx->ctx_static);
ggml_free(opt_ctx->ctx_static_cpu);
ggml_free(opt_ctx->ctx_cpu);
delete opt_ctx;
}
@@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
// ====== Computation ======
static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
if (graph != opt_ctx->gf) {
void ggml_opt_prepare_alloc(
ggml_opt_context_t opt_ctx,
struct ggml_context * ctx_compute,
struct ggml_cgraph * gf,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs) {
GGML_ASSERT(!opt_ctx->static_graphs);
opt_ctx->ctx_compute = ctx_compute;
opt_ctx->gf = gf;
opt_ctx->inputs = inputs;
opt_ctx->outputs = outputs;
}
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
GGML_ASSERT(!opt_ctx->eval_ready);
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
ggml_graph_reset(opt_ctx->gb_grad);
}
if (backward) {
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
} else {
opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
}
if (!opt_ctx->static_graphs) {
ggml_opt_build(opt_ctx);
}
struct ggml_cgraph * graph = nullptr;
switch (opt_ctx->build_type) {
case GGML_OPT_BUILD_TYPE_FORWARD: {
graph = opt_ctx->gf;
} break;
case GGML_OPT_BUILD_TYPE_GRAD: {
graph = opt_ctx->gb_grad;
} break;
case GGML_OPT_BUILD_TYPE_OPT: {
graph = opt_ctx->gb_opt;
} break;
}
GGML_ASSERT(graph);
if (opt_ctx->allocated_graph == graph) {
opt_ctx->eval_ready = true;
return;
}
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
if (opt_ctx->static_graphs) {
ggml_init_params params = {
/*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ggml_free(opt_ctx->ctx_copy);
opt_ctx->ctx_copy = ggml_init(params);
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
} else {
opt_ctx->allocated_graph_copy = graph;
}
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
opt_ctx->allocated_graph = graph;
opt_ctx->eval_ready = true;
}
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
GGML_ASSERT(opt_ctx->eval_ready);
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
@@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
adamw_par_data[6] = beta2h;
}
ggml_opt_alloc_graph(opt_ctx, graph);
ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
if (!opt_ctx->static_graphs) {
opt_ctx->gf = nullptr;
opt_ctx->gb_grad = nullptr;
opt_ctx->gb_opt = nullptr;
opt_ctx->allocated_graph = nullptr;
opt_ctx->allocated_graph_copy = nullptr;
}
opt_ctx->eval_ready = false;
if (!result) {
return;
@@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
result->loss.push_back(loss);
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
std::vector<int32_t> pred(ndata);
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
if (opt_ctx->pred) {
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
std::vector<int32_t> pred(ndata);
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
}
if (!opt_ctx->labels || result->ncorrect < 0) {
if (!opt_ctx->ncorrect || result->ncorrect < 0) {
result->ncorrect = -1;
return;
}
@@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
result->ncorrect += ncorrect;
}
void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
}
void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
if (opt_ctx->opt_period == 1) {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
return;
}
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
if (opt_i_next == 0) {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
} else {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
}
opt_ctx->opt_i = opt_i_next;
}
// ====== High-Level Functions ======
void ggml_opt_epoch(
@@ -700,16 +860,18 @@ void ggml_opt_epoch(
int64_t ibatch = 0;
int64_t t_loop_start = ggml_time_us();
for (; ibatch < ibatch_split; ++ibatch) {
ggml_opt_alloc(opt_ctx, /*backward =*/ true);
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
ggml_opt_forward_backward(opt_ctx, result_train);
ggml_opt_eval(opt_ctx, result_train);
if (callback_train) {
callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
}
}
t_loop_start = ggml_time_us();
for (; ibatch < nbatches; ++ibatch) {
ggml_opt_alloc(opt_ctx, /*backward =*/ false);
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
ggml_opt_forward(opt_ctx, result_eval);
ggml_opt_eval(opt_ctx, result_eval);
if (callback_eval) {
callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
}
@@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar(
int64_t t_start_us) {
fprintf(stderr, "%s[", train ? "train: " : "val: ");
constexpr int64_t bar_length = 25;
// The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
constexpr int64_t bar_length = 8;
const int64_t ibatch8 = 8 * ibatch;
for (int64_t j = 0; j < bar_length; ++j) {
const int64_t ibatch_j = ibatch_max * j/bar_length;
if (ibatch_j < ibatch) {
fprintf(stderr, "=");
} else if (ibatch_max * (j - 1)/bar_length < ibatch) {
fprintf(stderr, ">");
if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
fprintf(stderr, "\u2588"); // full block
} else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
fprintf(stderr, "\u2589"); // 7/8 filled
} else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
fprintf(stderr, "\u258A"); // 6/8 filled
} else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
fprintf(stderr, "\u258B"); // 5/8 filled
} else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
fprintf(stderr, "\u258C"); // 4/8 filled
} else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
fprintf(stderr, "\u258D"); // 3/8 filled
} else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
fprintf(stderr, "\u258E"); // 2/8 filled
} else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
fprintf(stderr, "\u258F"); // 1/8 filled
} else {
fprintf(stderr, " ");
}
@@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar(
const int64_t t_eta_m = t_eta_s / 60;
t_eta_s -= t_eta_m * 60;
fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
if (ibatch == ibatch_max) {
@@ -806,7 +981,10 @@ void ggml_opt_fit(
int64_t epoch = 1;
ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
params.ctx_compute = ctx_compute;
params.inputs = inputs;
params.outputs = outputs;
params.opt_period = opt_period;
params.get_opt_pars = get_opt_pars;
params.get_opt_pars_ud = &epoch;
+29 -23
View File
@@ -49,35 +49,38 @@ endif()
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
# Link against oneDNN
find_package(DNNL)
set(GGML_SYCL_DNNL 0)
if(DNNL_FOUND)
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
# Assuming oneDNN packaged with oneapi release is used which
# supports only intel target
set(DNNL_GPU_VENDOR "INTEL")
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
if(GGML_SYCL_DNN)
find_package(DNNL)
if(DNNL_FOUND)
if (NOT DEFINED DNNL_GPU_VENDOR)
# default to intel target
set(DNNL_GPU_VENDOR "INTEL")
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
endif()
endif()
endif()
# Verify oneDNN was compiled for the same target as llama
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
set(GGML_SYCL_DNNL 1)
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
foreach(CONFIG ${CONFIGS})
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
message(STATUS "Found oneDNN: ${DNNL_LIB}")
endforeach()
# Verify oneDNN was compiled for the same target as llama
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
set(GGML_SYCL_DNNL 1)
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
foreach(CONFIG ${CONFIGS})
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
message(STATUS "Found oneDNN: ${DNNL_LIB}")
endforeach()
else()
message(WARNING
"oneDNN must be compiled for the same target as llama.cpp.
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
Disabling oneDNN support.")
endif()
else()
message(WARNING
"oneDNN must be compiled for the same target as llama.cpp.
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
Disabling oneDNN support.")
message(STATUS "oneDNN not found, disabling oneDNN support")
endif()
else()
message(STATUS "oneDNN not found, disabling oneDNN support")
message(STATUS "oneDNN support disabled by the user")
endif()
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
@@ -108,6 +111,9 @@ endif()
if (GGML_SYCL_TARGET STREQUAL "INTEL")
# Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
# See https://github.com/uxlfoundation/oneMath/issues/654
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set(SYCL_COMPILER ON)
endif()
find_package(MKL REQUIRED)
target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
+109 -220
View File
@@ -1,93 +1,74 @@
#include "binbcast.hpp"
#include <array>
#include <cstddef>
#include <cstdint>
#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "ggml.h"
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0)) /
ne3;
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0)) %
ne3;
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0;
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
auto element_id = it.get_global_id(0);
auto global_range = it.get_global_range(0);
for (; element_id < num_elements; element_id += global_range) {
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
dst[element_id] = val_to_store;
}
}
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
int s11, int s12, int s13, std::size_t num_dst_elements,
const sycl::nd_item<1> & item_ct1) {
auto calculate_logical_index =
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
std::array<int, 4> logical_index;
#pragma unroll(4)
for (int i = 3; i >= 0; i--) {
logical_index[i] = element_id % dims[i];
element_id /= dims[i];
}
return logical_index;
};
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
const std::array<int, 4> & indices) __attribute__((always_inline))
->std::size_t {
std::size_t index = 0;
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
auto index_i = indices[i];
if (indices[i] >= dims[i]) {
index_i = indices[i] % dims[i];
}
index += strides[i] * index_i;
}
return index;
};
const int i3 = i/(ne2*ne1*ne0);
const int i2 = (i/(ne1*ne0)) % ne2;
const int i1 = (i/ne0) % ne1;
const int i0 = i % ne0;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
auto element_id = item_ct1.get_global_id(0);
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
dst[dst_index] = val_to_store;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
template<float (*bin_op)(const float, const float)>
struct bin_bcast_sycl {
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
template <typename src0_t, typename src1_t, typename dst_t>
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -96,165 +77,73 @@ struct bin_bcast_sycl {
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
int nr0 = ne10 / ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int nr3 = ne13/ne3;
int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension
int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
};
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
const std::array<int64_t, 4> & dst_dims) -> bool {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
if (dst_dims[i] > src_dims[i]) {
return true;
}
}
}
{
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
return false;
};
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
// dst strides in number of elements
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
// src1 strides in number of elements
size_t s10 = nb10 / sizeof(src0_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
// src0 strides in number of elements
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_UNUSED(s00);
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
std::size_t local_range = 256;
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL);
sycl::range<3> block_dims(1, 1, 1);
block_dims[2] = std::min<unsigned int>(hne0, block_size);
block_dims[1] = std::min<unsigned int>(
ne1, block_size / (unsigned int)block_dims[2]);
block_dims[0] = std::min(
std::min<unsigned int>(
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
(unsigned int)block_dims[1]),
64U);
sycl::range<3> block_nums(
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
(ne1 + block_dims[1] - 1) / block_dims[1],
(hne0 + block_dims[2] - 1) / block_dims[2]);
if (block_nums[0] > 65535) {
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
sycl::range<3>(1, 1, block_size),
sycl::range<3>(1, 1, block_size)),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
s03, s11, s12, s13, item_ct1);
});
}
} else {
/*
DPCT1049:16: The work-group size passed to the SYCL kernel may
exceed the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if
needed.
*/
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
item_ct1);
});
}
if (! needs_broadcasting && all_contiguous) {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
});
});
} else {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
});
});
}
}
};
+29 -2
View File
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
}
}
template <typename dst_t>
static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
const int64_t nb = k / QK_K;
const size_t local_size = 32;
const size_t global_size = nb * local_size;
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->submit([&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
});
});
}
template <typename dst_t>
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_sycl;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_sycl;
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q4_K_sycl_reorder;
} else {
return dequantize_row_q4_K_sycl;
}
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
case GGML_TYPE_Q6_K:
@@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_sycl;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_sycl;
if (dst->src[0]->extra &&
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q4_K_sycl_reorder;
} else {
return dequantize_row_q4_K_sycl;
}
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
case GGML_TYPE_Q6_K:
+59 -21
View File
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
}
#endif
template <typename dst_t>
inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
const int is = 2 * il;
constexpr int n = 4;
uint8_t sc, m;
get_scale_min_k4(is + 0, scales_local, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, scales_local, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
}
}
template<typename dst_t>
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
const int64_t i = item_ct1.get_group(2);
#if QK_K == 256
// assume 32 threads
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t is = 2*il;
const int64_t n = 4;
const int64_t il = tid / 8;
const int64_t ir = tid % 8;
dst_t * y = yy + i*QK_K + 64*il + n*ir;
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
const sycl::half2 dm = x[i].dm;
const float dall = dm[0];
const float dmin = dm[1];
if (tid < 12)
if (tid < 12) {
scales_local[tid] = x[i].scales[tid];
item_ct1.barrier(sycl::access::fence_space::local_space);
uint8_t sc, m;
get_scale_min_k4(is + 0, scales_local, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, scales_local, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
#else
const int64_t tid = item_ct1.get_local_id(2);
const uint8_t * q = x[i].qs;
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
#endif
}
template <typename dst_t>
static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
const sycl::nd_item<1> & item_ct1, int64_t nb) {
const int64_t i = item_ct1.get_group(0); // block index
const int64_t tid = item_ct1.get_local_id(0); // thread index within block
const int64_t il = tid / 8;
const int64_t ir = tid % 8;
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
const uint8_t * base = static_cast<const uint8_t *>(vx);
const size_t qs_offset = i * (QK_K / 2);
const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
const uint8_t * qs_ptr = base + qs_offset;
const uint8_t * scales_ptr = base + scales_offset;
ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
const float dall = dm_values.x();
const float dmin = dm_values.y();
if (tid < 12) {
scales_local[tid] = scales_ptr[tid];
}
item_ct1.barrier(sycl::access::fence_space::local_space);
dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
}
template<typename dst_t>
static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
+7 -1
View File
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_K:
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
// reorder is currently not supported for dmmv
GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
} else {
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q5_K:
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
-23
View File
@@ -655,7 +655,6 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -688,7 +687,6 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -722,7 +720,6 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -754,7 +751,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -786,7 +782,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -818,7 +813,6 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -850,7 +844,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -883,7 +876,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -917,7 +909,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -949,7 +940,6 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -981,7 +971,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1013,7 +1002,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1045,7 +1033,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1078,7 +1065,6 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1110,7 +1096,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1142,7 +1127,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1174,7 +1158,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1206,7 +1189,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1241,7 +1223,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1273,7 +1254,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1315,7 +1295,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1350,7 +1329,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
@@ -1388,7 +1366,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
break;
}
}
+37 -8
View File
@@ -32,16 +32,36 @@ public:
else static_assert(0);
}
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
// matrix A has m rows, k columns
// matrix B has k rows, n columns
// nra - number of elements to skip when moving into next row in A
// nrb - number of elements to skip when moving into next row in B
// nca - number of elements to skip when moving into next column in A
// ncb - number of elements to skip when moving into next column in B
// stride_a - number of elements to skip when moving to next A matrix
// stride_b - number of elements to skip when moving to next B matrix
// batches_a - number of A matrices
// batches_b - number of B matrices
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
// { # strides, # rows, # columns }
dnnl::memory::dims a_dims = { batches_a, m, k };
dnnl::memory::dims b_dims = { batches_b, k, n };
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
dnnl::memory::dims a_strides = { stride_a, nra, nca };
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
matmul_prim.execute(stream, matmul_args);
}
// matrices A and B are column major, both having k rows
// matrix A has m column, matrix B has n columns
// output: column major matrix C = A transposed * B
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
}
};
#endif
+190 -80
View File
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
#ifdef GGML_SYCL_GRAPH
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
#endif
#if GGML_SYCL_DNNL
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
@@ -341,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne00 == ne10);
const int64_t row_diff = row_high - row_low;
int id;
SYCL_CHECK(
CHECK_TRY_ERROR(id = get_current_device_id()));
#if !GGML_SYCL_DNNL
const int64_t ne0 = dst->ne[0];
const int64_t ne0 = dst->ne[0]; // used by MKL only
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = id == ctx.device ? ne0 : row_diff;
#endif
int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
#ifdef GGML_SYCL_F16
bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
: src1_as_f16.get();
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
#if !GGML_SYCL_DNNL
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
dst_f16.get(), dpct::library_data_t::real_half, ldc,
dpct::library_data_t::real_half)));
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
}
else
#endif
{
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
dst_f16.get(), dpct::library_data_t::real_half, ldc,
dpct::library_data_t::real_half)));
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
}
else {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
#if !GGML_SYCL_DNNL
const float alpha = 1.0f;
const float beta = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
#endif
{
const float alpha = 1.0f;
const float beta = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
}
}
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
}
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
char * dst_t = reinterpret_cast<char *>(dst_ddf);
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
GGML_ASSERT(ne10 == ne00);
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
};
sycl::range<3> block_dims(1, ne12, ne13);
queue->submit([&](sycl::handler & cgh) {
const void ** ptrs_src_get = ptrs_src.get();
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
if (r2 == 1 && r3 == 1) {
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
}
else {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
}
}
} else {
// iterate over batches from smaller set of matrices (matrix 0)
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
}
}
}
}
else
#endif
{
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
sycl::range<3> block_dims(1, ne12, ne13);
queue->submit([&](sycl::handler & cgh) {
const void ** ptrs_src_get = ptrs_src.get();
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
});
});
});
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
}
}
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@@ -2841,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return true;
case GGML_TYPE_Q4_K:
return !g_ggml_sycl_prioritize_dmmv;
default:
return false;
}
@@ -2858,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_K:
return true;
default:
return false;
@@ -2883,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
}
}
static void reorder_qw(char *data_device, const int ncols, const int nrows,
size_t size, size_t offset, dpct::queue_ptr stream) {
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
dpct::queue_ptr stream) {
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
SYCL_CHECK(
CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
.wait()));
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
int offset_blks = offset / sizeof(block_q4_0);
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
stream->parallel_for(
@@ -2906,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
}
*(d_ptr + ib) = x[ib].d;
});
}).wait_and_throw();
sycl::free(tmp_buf, *stream);
}
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
const int nblocks = size / sizeof(block_q4_K);
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
auto * qs_ptr = data_device;
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
stream->parallel_for(nblocks, [=](auto i) {
const block_q4_K * x = (const block_q4_K *) tmp_buf;
const int ib = i;
for (int j = 0; j < QK_K / 2; ++j) {
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
}
for (int j = 0; j < K_SCALE_SIZE; ++j) {
scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
}
dm_ptr[ib] = x[ib].dm;
}).wait_and_throw();
sycl::free(tmp_buf, *stream);
}
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
char*data_device = (char*)src0->data;
uint8_t * data_device = (uint8_t *) src0->data;
size_t ncols = src0->ne[0];
size_t nrows = src0->ne[1];
size_t size = ggml_nbytes(src0);
reorder_qw(data_device, ncols, nrows, size, 0, stream);
switch (src0->type) {
case GGML_TYPE_Q4_0:
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
break;
case GGML_TYPE_Q4_K:
reorder_qw_q4_k(data_device, size, 0, stream);
break;
default:
GGML_ABORT("reorder_qw() called with unsupported type");
break;
}
}
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
@@ -2960,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
}
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
}
static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
}
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
int64_t min_compute_capability = INT_MAX;
@@ -2984,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
}
// check data types and tensor shapes for custom matrix multiplication kernels:
bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -3713,7 +3822,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
return GGML_STATUS_SUCCESS;
}
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();
+29 -2
View File
@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
const int blocks_per_row = ncols / block_traits::qk;
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
const int nblocks = nrows * (ncols / block_traits::qk);
static_assert(blocks_per_subgroup > 0);
static_assert(block_elements_per_subgroup > 0);
@@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
// x block quant index when casting the quants to int
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
}
}
@@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
}
}
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
constexpr size_t num_subgroups = 16;
GGML_ASSERT(block_num_y % num_subgroups == 0);
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
nrows, nd_item);
});
});
}
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
@@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+22
View File
@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
template <> struct block_q_t<GGML_TYPE_Q4_K> {
struct traits {
static constexpr uint32_t qk = QK_K;
static constexpr uint32_t qi = QI4_K;
static constexpr uint32_t qr = QR4_K;
static constexpr uint32_t vdr_mmvq = 2;
};
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
auto nblocks = (nrows * (ncols / traits::qk));
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
}
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
};
} // namespace ggml_sycl_reordered
#endif // GGML_SYCL_QUANTS_HPP
+69 -43
View File
@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
}
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
int v[q4_0_traits::vdr_mmvq];
@@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
};
};
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
const int & iqs) {
int v[2];
int u[2 * QR4_K];
float d8[QR4_K];
v[0] = q4[0];
v[1] = q4[4];
uint16_t aux[2];
const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
if (j < 2) {
aux[0] = scales[j + 0] & 0x3f3f;
aux[1] = scales[j + 2] & 0x3f3f;
} else {
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
}
const uint8_t * sc = (const uint8_t *) aux;
const uint8_t * m = sc + 2;
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
d8[i] = bq8i->ds[0];
const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);
u[2 * i + 0] = q8[0];
u[2 * i + 1] = q8[4];
}
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);
}
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
static constexpr ggml_type gtype = GGML_TYPE_Q4_K;
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
using q4_k_traits = typename q4_k_block::traits;
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
const int ib = ibx_offset / (QK_K / 2);
const uint8_t * base = static_cast<const uint8_t *>(vbq);
const uint8_t * qs = base + ibx_offset;
const int total_qs_bytes = nblocks * (QK_K / 2);
const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
const uint16_t * scales = (const uint16_t *) scs;
return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
}
};
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
@@ -649,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}
static __dpct_inline__ float
vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
const int & iqs) {
#ifndef GGML_QKK_64
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
int v[2];
int u[2*QR4_K];
float d8[QR4_K];
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
const uint16_t * scales = (const uint16_t *) bq4_K->scales;
// iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
v[0] = q4[0];
v[1] = q4[4];
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
uint16_t aux[2];
const int j = bq8_offset/2;
if (j < 2) {
aux[0] = scales[j+0] & 0x3f3f;
aux[1] = scales[j+2] & 0x3f3f;
} else {
aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
}
const uint8_t * sc = (const uint8_t *)aux;
const uint8_t * m = sc + 2;
for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
d8[i] = bq8i->ds[0];
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
u[2*i+0] = q8[0];
u[2*i+1] = q8[4];
}
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);
#else
+79 -90
View File
@@ -15,6 +15,32 @@ function(detect_host_compiler)
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
endfunction()
# Function to test shader extension support
# Parameters:
# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product")
# TEST_SHADER_FILE - Path to the test shader file
# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result
function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)
execute_process(
COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error
)
if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*")
message(STATUS "${EXTENSION_NAME} not supported by glslc")
set(${RESULT_VARIABLE} OFF PARENT_SCOPE)
else()
message(STATUS "${EXTENSION_NAME} supported by glslc")
set(${RESULT_VARIABLE} ON PARENT_SCOPE)
add_compile_definitions(${RESULT_VARIABLE})
# Ensure the extension support is forwarded to vulkan-shaders-gen
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE)
endif()
endfunction()
if (Vulkan_FOUND)
message(STATUS "Vulkan found")
@@ -23,69 +49,35 @@ if (Vulkan_FOUND)
../../include/ggml-vulkan.h
)
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
set(VULKAN_SHADER_GEN_CMAKE_ARGS
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
)
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
endif()
# Test all shader extensions
test_shader_extension_support(
"GL_KHR_cooperative_matrix"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
)
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
test_shader_extension_support(
"GL_NV_cooperative_matrix2"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
endif()
test_shader_extension_support(
"GL_EXT_integer_dot_product"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
)
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
endif()
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_EXT_bfloat16 supported by glslc")
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
endif()
test_shader_extension_support(
"GL_EXT_bfloat16"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
)
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
@@ -124,16 +116,8 @@ if (Vulkan_FOUND)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif()
if (NOT CMAKE_CROSSCOMPILING)
add_subdirectory(vulkan-shaders)
if (MSVC)
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
string(TOUPPER ${CONFIG} CONFIG)
set_target_properties(vulkan-shaders-gen PROPERTIES
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()
else()
# Set up toolchain for host compilation whether cross-compiling or not
if (CMAKE_CROSSCOMPILING)
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
else()
@@ -146,25 +130,31 @@ if (Vulkan_FOUND)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
endif()
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
include(ExternalProject)
# Native build through ExternalProject_Add
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
)
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
else()
# For non-cross-compiling, use empty toolchain (use host compiler)
set(HOST_CMAKE_TOOLCHAIN_FILE "")
endif()
# Always use ExternalProject_Add approach
include(ExternalProject)
# Add toolchain file if cross-compiling
if (CMAKE_CROSSCOMPILING)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
endif()
# Native build through ExternalProject_Add
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
)
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
@@ -175,9 +165,8 @@ if (Vulkan_FOUND)
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
if (CMAKE_CROSSCOMPILING)
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
endif()
# Add build and install dependencies for all builds
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
add_custom_command(
OUTPUT ${_ggml_vk_header}
+182 -51
View File
@@ -288,6 +288,9 @@ struct vk_device_struct {
bool coopmat_acc_f32_support {};
bool coopmat_acc_f16_support {};
bool coopmat_bf16_support {};
bool coopmat_support_16x16x16_f16acc {};
bool coopmat_support_16x16x16_f32acc {};
bool coopmat1_fa_support {};
uint32_t coopmat_m;
uint32_t coopmat_n;
uint32_t coopmat_k;
@@ -410,6 +413,13 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
@@ -1588,19 +1598,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
);
}
enum FaCodePath {
FA_SCALAR,
FA_COOPMAT1,
FA_COOPMAT2,
};
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
static uint32_t get_fa_num_small_rows(bool scalar) {
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
// 128 threads split into four subgroups, each subgroup does 1/4
// of the Bc dimension.
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
static constexpr uint32_t scalar_flash_attention_Bc = 64;
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
static uint32_t get_fa_num_small_rows(FaCodePath path) {
if (path == FA_COOPMAT2) {
return flash_attention_num_small_rows;
} else {
return scalar_flash_attention_num_small_rows;
}
}
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp);
if (scalar) {
if (path == FA_SCALAR) {
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
@@ -1608,9 +1635,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
}
}
if (path == FA_COOPMAT1) {
if (small_rows) {
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
} else {
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
}
}
// small rows, large cols
if (small_rows) {
return {get_fa_num_small_rows(scalar), 32};
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
}
// small cols to reduce register count
@@ -1907,17 +1942,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
};
auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
// For scalar, use 128 (arbitrary)
uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -1929,36 +1966,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
};
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
CREATE_FA(GGML_TYPE_F16, f16, true, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
}
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
}
#endif
#undef CREATE_FA2
@@ -2041,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -3009,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
#if defined(VK_KHR_cooperative_matrix)
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
// coopmat1 fa shader currently assumes 32 invocations per subgroup
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
device->subgroup_max_size >= 32;
#endif
if (coopmat2_support) {
@@ -3143,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
// Only enable if shape is identical
device->coopmat_acc_f32_support = true;
}
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
device->coopmat_support_16x16x16_f32acc = true;
}
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
// coopmat sizes not set yet
@@ -3155,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
// Only enable if shape is identical
device->coopmat_acc_f16_support = true;
}
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
device->coopmat_support_16x16x16_f16acc = true;
}
}
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
@@ -5688,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
// Needs to be kept up to date on shader changes
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = D / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float);
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
@@ -5738,7 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
assert(q->type == GGML_TYPE_F32);
assert(k->type == v->type);
bool scalar = !ctx->device->coopmat2;
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
if (path == FA_COOPMAT1) {
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;
}
}
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
@@ -5746,9 +5843,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
// For scalar FA, we can use the "large" size to accommodate qga.
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
uint32_t max_gqa;
switch (path) {
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = scalar_flash_attention_num_large_rows;
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
break;
default:
GGML_ASSERT(0);
}
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
@@ -5761,11 +5870,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
vk_pipeline *pipelines;
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
bool small_rows = N <= get_fa_num_small_rows(scalar);
bool small_rows = N <= get_fa_num_small_rows(path);
if (scalar) {
if (small_rows && path == FA_COOPMAT1) {
path = FA_SCALAR;
}
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
switch (path) {
case FA_SCALAR:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
@@ -5777,7 +5891,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(!"unsupported D value");
return;
}
} else {
break;
case FA_COOPMAT1:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break;
case FA_COOPMAT2:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
@@ -5789,6 +5917,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(!"unsupported D value");
return;
}
break;
default:
GGML_ASSERT(0);
}
assert(pipelines);
@@ -5,18 +5,35 @@ find_package (Threads REQUIRED)
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat glslc support")
endif()
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat2 glslc support")
endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
message(STATUS "Enabling dot glslc support")
endif()
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
message(STATUS "Enabling bfloat16 glslc support")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
# Configure output directories for MSVC builds
if(MSVC)
# Get the main project's runtime output directory if possible
if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
string(TOUPPER ${CONFIG} CONFIG)
set_target_properties(${TARGET} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()
endif()
@@ -12,6 +12,7 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
@@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (push_constant) uniform parameter {
@@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared vec4 tmpshv4[gl_WorkGroupSize.x];
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4];
@@ -0,0 +1,506 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#include "types.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * D + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for D==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t N = p.N;
const uint32_t KV = p.KV;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
#define tile_row(r) (row_tid * rows_per_thread + (r))
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;
const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;
// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;
// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t r = (idx + tid) / (D / 4);
if (r < Br && d < D / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
// ALiBi
if (p.max_bias > 0.0f) {
if (tid < Br) {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
}
barrier();
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t c = (idx + tid) / (D / 4);
if (c < Bc && d < D / 4) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < D / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
uint coord = gl_SubgroupID * MatBc * sfshstride;
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if (p.mask != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
}
}
barrier();
}
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
barrier();
}
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]);
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}
@@ -215,7 +215,7 @@ static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname);
@@ -424,6 +424,7 @@ void process_shaders() {
// flash attention
for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float";
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
for (const auto& tname : type_names) {
if (tname == "f32") {
@@ -440,6 +441,16 @@ void process_shaders() {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+23 -18
View File
@@ -5499,7 +5499,7 @@ static void ggml_compute_backward(
// tensor = src0 * 1 + src1 * 0
if (src0_needs_grads) {
// dsrc0 = dtensor * 1
ggml_add_or_set(ctx, cgraph, isrc0, grad);
ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
}
if (src1_needs_grads) {
// dsrc1 = dtensor * 0 -> noop
@@ -5780,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
}
void ggml_build_backward_expand(
struct ggml_context * ctx_static,
struct ggml_context * ctx_compute,
struct ggml_cgraph * cgraph,
bool accumulate) {
struct ggml_context * ctx,
struct ggml_cgraph * cgraph,
struct ggml_tensor ** grad_accs) {
GGML_ASSERT(cgraph->n_nodes > 0);
GGML_ASSERT(cgraph->grads);
GGML_ASSERT(cgraph->grad_accs);
@@ -5856,21 +5855,24 @@ void ggml_build_backward_expand(
GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(igrad != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
cgraph->grads[igrad] = cgraph->grad_accs[igrad];
ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(ihash != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
if (grad_accs && grad_accs[i]) {
cgraph->grad_accs[ihash] = grad_accs[i];
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
} else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
// loss tensors always need a gradient accumulator
cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
}
grads_needed[igrad] = true;
grads_needed[ihash] = true;
}
for (int i = n_nodes_f - 1; i >= 0; --i) {
// inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
// use allocator to automatically make inplace operations
ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
ggml_compute_backward(ctx, cgraph, i, grads_needed);
}
free(grads_needed);
@@ -6016,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
}
}
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
ggml_graph_cpy(cgraph, result);
return result;
}
@@ -6036,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
}
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
if (!cgraph) {
return;
}
GGML_ASSERT(cgraph->grads != NULL);
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6345,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
}
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
GGML_UNUSED(ctx); // TODO: remove this parameter
void ggml_set_param(struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->op == GGML_OP_NONE);
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
}
+33 -33
View File
@@ -299,10 +299,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
return false;
}
} catch (std::length_error &) {
fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
GGML_LOG_ERROR("%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
return false;
} catch (std::bad_alloc &) {
fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
return false;
}
kv.emplace_back(key, value);
@@ -328,14 +328,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
ok = ok && gr.read(magic, 4);
if (!ok) {
fprintf(stderr, "%s: failed to read magic\n", __func__);
GGML_LOG_ERROR("%s: failed to read magic\n", __func__);
gguf_free(ctx);
return nullptr;
}
for (uint32_t i = 0; i < magic.size(); i++) {
if (magic[i] != GGUF_MAGIC[i]) {
fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
GGML_LOG_ERROR("%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
gguf_free(ctx);
return nullptr;
}
@@ -348,11 +348,11 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
if (ok && gr.read(ctx->version)) {
if (ctx->version == 1) {
fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
ok = false;
}
if (ctx->version > GGUF_VERSION) {
fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
__func__, ctx->version, GGUF_VERSION);
ok = false;
}
@@ -363,7 +363,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
if (ok && gr.read(n_tensors)) {
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
__func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
ok = false;
}
@@ -374,7 +374,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
if (ok && gr.read(n_kv)) {
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
__func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
ok = false;
}
@@ -383,7 +383,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
}
if (!ok) {
fprintf(stderr, "%s: failed to read header\n", __func__);
GGML_LOG_ERROR("%s: failed to read header\n", __func__);
gguf_free(ctx);
return nullptr;
}
@@ -399,15 +399,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
try {
ok = ok && gr.read(key);
} catch (std::length_error &) {
fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
ok = false;
} catch (std::bad_alloc &) {
fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
ok = false;
}
for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
if (key == ctx->kv[j].key) {
fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
ok = false;
}
}
@@ -441,14 +441,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
case GGUF_TYPE_ARRAY:
default:
{
fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
ok = false;
} break;
}
}
if (!ok) {
fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__);
gguf_free(ctx);
return nullptr;
}
@@ -458,7 +458,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
gguf_free(ctx);
return nullptr;
}
@@ -474,14 +474,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
try {
ok = ok && gr.read(name);
} catch (std::length_error &) {
fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
ok = false;
} catch (std::bad_alloc &) {
fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
ok = false;
}
if (name.length() >= GGML_MAX_NAME) {
fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
ok = false;
break;
}
@@ -490,7 +490,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// make sure there are no duplicate tensor names
for (int64_t j = 0; ok && j < i; ++j) {
if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
ok = false;
break;
}
@@ -505,7 +505,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
uint32_t n_dims = -1;
ok = ok && gr.read(n_dims);
if (n_dims > GGML_MAX_DIMS) {
fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
ok = false;
break;
@@ -518,7 +518,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// check that all ne are non-negative
if (info.t.ne[j] < 0) {
fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
__func__, info.t.name, j, info.t.ne[j]);
ok = false;
break;
@@ -530,7 +530,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
(INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape "
"(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
__func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
ok = false;
@@ -547,7 +547,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// check that tensor type is within defined range
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
ok = false;
break;
@@ -557,7 +557,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// check that row size is divisible by block size
if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
"not a multiple of block size (%" PRId64 ")\n",
__func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
ok = false;
@@ -582,7 +582,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
}
if (!ok) {
fprintf(stderr, "%s: failed to read tensor info\n", __func__);
GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__);
gguf_free(ctx);
return nullptr;
}
@@ -590,7 +590,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// we require the data section to be aligned, so take into account any padding
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
gguf_free(ctx);
return nullptr;
}
@@ -604,9 +604,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
for (size_t i = 0; i < ctx->info.size(); ++i) {
const gguf_tensor_info & ti = ctx->info[i];
if (ti.offset != ctx->size) {
fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
__func__, ti.t.name, ti.offset, ctx->size);
fprintf(stderr, "%s: failed to read tensor data\n", __func__);
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
gguf_free(ctx);
return nullptr;
}
@@ -634,7 +634,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
*params.ctx = ggml_init(pdata);
if (*params.ctx == nullptr) {
fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__);
gguf_free(ctx);
return nullptr;
}
@@ -656,7 +656,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
ok = ok && gr.read(data->data, ctx->size);
if (!ok) {
fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__);
ggml_free(ctx_data);
*params.ctx = nullptr;
gguf_free(ctx);
@@ -689,7 +689,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
}
if (!ok) {
fprintf(stderr, "%s: failed to create tensors\n", __func__);
GGML_LOG_ERROR("%s: failed to create tensors\n", __func__);
ggml_free(ctx_data);
*params.ctx = nullptr;
gguf_free(ctx);
@@ -706,7 +706,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
FILE * file = ggml_fopen(fname, "rb");
if (!file) {
fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
return nullptr;
}
@@ -1305,7 +1305,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
FILE * file = ggml_fopen(fname, "wb");
if (!file) {
fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
return false;
}
+9
View File
@@ -483,7 +483,9 @@ class MODEL_TENSOR(IntEnum):
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_POS = auto()
V_ENC_ATTN_Q = auto()
V_ENC_ATTN_Q_NORM = auto()
V_ENC_ATTN_K = auto()
V_ENC_ATTN_K_NORM = auto()
V_ENC_ATTN_V = auto()
V_ENC_INPUT_NORM = auto()
V_ENC_OUTPUT = auto()
@@ -742,7 +744,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm",
MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm",
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
@@ -782,7 +786,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_ATTN_Q,
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
MODEL_TENSOR.V_ENC_ATTN_K,
MODEL_TENSOR.V_ENC_ATTN_K_NORM,
MODEL_TENSOR.V_ENC_ATTN_V,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_OUTPUT,
@@ -1899,6 +1905,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
],
MODEL_ARCH.CHAMELEON: [
MODEL_TENSOR.TOKEN_EMBD,
+7 -3
View File
@@ -823,6 +823,7 @@ class GGUFEditorWindow(QMainWindow):
self.modified = False
self.metadata_changes = {} # Store changes to apply when saving
self.metadata_to_remove = set() # Store keys to remove when saving
self.on_metadata_changed_is_connected = False
self.setup_ui()
@@ -941,9 +942,11 @@ class GGUFEditorWindow(QMainWindow):
return
# Disconnect to prevent triggering during loading
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
if self.on_metadata_changed_is_connected:
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
self.on_metadata_changed_is_connected = False
for i, (key, field) in enumerate(self.reader.fields.items()):
self.metadata_table.insertRow(i)
@@ -1021,6 +1024,7 @@ class GGUFEditorWindow(QMainWindow):
# Reconnect after loading
self.metadata_table.itemChanged.connect(self.on_metadata_changed)
self.on_metadata_changed_is_connected = True
def extract_array_values(self, field: ReaderField) -> list:
"""Extract all values from an array field."""
+38 -29
View File
@@ -68,7 +68,7 @@ class TensorNameMap:
"output_layer", # chatglm
"head", # rwkv
"head.out", # wavtokenizer
"language_model.lm_head", # llama4
"lm_head", # llama4
),
# Output norm
@@ -91,7 +91,7 @@ class TensorNameMap:
"rwkv.ln_out", # rwkv6
"model.ln_out", # rwkv7
"backbone.final_layer_norm", # wavtokenizer
"language_model.model.norm", # llama4
"model.norm", # llama4
),
# Rope frequencies
@@ -133,7 +133,7 @@ class TensorNameMap:
"transformer.layers.{bid}.attn_norm", # openelm
"rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7
"language_model.model.layers.{bid}.input_layernorm", # llama4
"model.layers.{bid}.input_layernorm", # llama4
),
# Attention norm 2
@@ -173,7 +173,7 @@ class TensorNameMap:
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
"transformer.h.{bid}.attn.attention.q_proj", # exaone
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4
"model.layers.{bid}.self_attn.q_proj", # llama4
),
# Attention key
@@ -188,7 +188,7 @@ class TensorNameMap:
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
"transformer.h.{bid}.attn.attention.k_proj", # exaone
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4
"model.layers.{bid}.self_attn.k_proj", # llama4
),
# Attention value
@@ -202,7 +202,7 @@ class TensorNameMap:
"model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
"transformer.h.{bid}.attn.attention.v_proj", # exaone
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4
"model.layers.{bid}.self_attn.v_proj", # llama4
),
# Attention output
@@ -229,7 +229,7 @@ class TensorNameMap:
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4
"model.layers.{bid}.self_attn.o_proj", # llama4
),
# Attention output norm
@@ -268,7 +268,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
"model.layers.{bid}.post_attention_layernorm", # llama4
),
# Post feed-forward norm
@@ -289,7 +289,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"language_model.model.layers.{bid}.feed_forward.router", # llama4
"model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
),
@@ -329,7 +329,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
"model.layers.{bid}.feed_forward.up_proj", # llama4
),
MODEL_TENSOR.FFN_UP_EXP: (
@@ -338,14 +338,14 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
),
MODEL_TENSOR.FFN_UP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
),
# AWQ-activation gate
@@ -366,22 +366,22 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
"model.layers.{bid}.feed_forward.gate_proj", # llama4
),
MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
),
MODEL_TENSOR.FFN_GATE_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
),
# Feed-forward down
@@ -410,7 +410,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
"model.layers.{bid}.feed_forward.down_proj", # llama4
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -420,14 +420,15 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
),
MODEL_TENSOR.ATTN_Q_NORM: (
@@ -938,6 +939,10 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL
),
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
@@ -946,6 +951,10 @@ class TensorNameMap:
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL
),
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
+38 -1
View File
@@ -4,6 +4,7 @@
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h>
#include <stdint.h>
@@ -344,7 +345,7 @@ extern "C" {
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
@@ -363,6 +364,7 @@ extern "C" {
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
};
// model quantization parameters
@@ -444,6 +446,10 @@ extern "C" {
size_t n_paths,
struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead");
@@ -1432,6 +1438,37 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor contains trainable parameters
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus
}
#endif
+337 -125
View File
@@ -7,6 +7,10 @@ import sys
import os
from glob import glob
import sqlite3
import json
import csv
from typing import Optional, Union
from collections.abc import Iterator, Sequence
try:
import git
@@ -17,6 +21,28 @@ except ImportError as e:
logger = logging.getLogger("compare-llama-bench")
# All llama-bench SQL fields
DB_FIELDS = [
"build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
"defrag_thold",
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
]
DB_TYPES = [
"TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT",
"TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
"TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
"TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
"REAL",
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
]
assert len(DB_FIELDS) == len(DB_TYPES)
# Properties by which to differentiate results per commit:
KEY_PROPERTIES = [
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
@@ -42,7 +68,7 @@ DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default.
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
$ git checkout master
$ make clean && make llama-bench
@@ -70,12 +96,13 @@ help_c = (
)
parser.add_argument("-c", "--compare", help=help_c)
help_i = (
"Input SQLite file for comparing commits. "
"JSON/JSONL/SQLite/CSV files for comparing commits. "
"Specify multiple times to use multiple input files (JSON/CSV only). "
"Defaults to 'llama-bench.sqlite' in the current working directory. "
"If no such file is found and there is exactly one .sqlite file in the current directory, "
"that file is instead used as input."
)
parser.add_argument("-i", "--input", help=help_i)
parser.add_argument("-i", "--input", action="append", help=help_i)
help_o = (
"Output format for the table. "
"Defaults to 'pipe' (GitHub compatible). "
@@ -86,7 +113,7 @@ parser.add_argument("-o", "--output", help=help_o, default="pipe")
help_s = (
"Columns to add to the table. "
"Accepts a comma-separated list of values. "
f"Legal values: {', '.join(KEY_PROPERTIES[:-2])}. "
f"Legal values: {', '.join(KEY_PROPERTIES[:-3])}. "
"Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
"plus any column where not all data points are the same. "
"If the columns are manually specified, then the results for each unique combination of the "
@@ -110,119 +137,321 @@ if unknown_args:
sys.exit(1)
input_file = known_args.input
if input_file is None and os.path.exists("./llama-bench.sqlite"):
input_file = "llama-bench.sqlite"
if input_file is None:
if not input_file and os.path.exists("./llama-bench.sqlite"):
input_file = ["llama-bench.sqlite"]
if not input_file:
sqlite_files = glob("*.sqlite")
if len(sqlite_files) == 1:
input_file = sqlite_files[0]
input_file = sqlite_files
if input_file is None:
if not input_file:
logger.error("Cannot find a suitable input file, please provide one.\n")
parser.print_help()
sys.exit(1)
connection = sqlite3.connect(input_file)
cursor = connection.cursor()
build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
class LlamaBenchData:
repo: Optional[git.Repo]
build_len_min: int
build_len_max: int
build_len: int = 8
builds: list[str] = []
check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
if build_len_min != build_len_max:
logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
"Try purging the the database of old commits.")
cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});")
def __init__(self):
try:
self.repo = git.Repo(".", search_parent_directories=True)
except git.InvalidGitRepositoryError:
self.repo = None
build_len: int = build_len_min
def _builds_init(self):
self.build_len = self.build_len_min
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
if not builds:
raise RuntimeError(f"{input_file} does not contain any builds.")
try:
repo = git.Repo(".", search_parent_directories=True)
except git.InvalidGitRepositoryError:
repo = None
def find_parent_in_data(commit: git.Commit):
"""Helper function to find the most recent parent measured in number of commits for which there is data."""
heap: list[tuple[int, git.Commit]] = [(0, commit)]
seen_hexsha8 = set()
while heap:
depth, current_commit = heapq.heappop(heap)
current_hexsha8 = commit.hexsha[:build_len]
if current_hexsha8 in builds:
return current_hexsha8
for parent in commit.parents:
parent_hexsha8 = parent.hexsha[:build_len]
if parent_hexsha8 not in seen_hexsha8:
seen_hexsha8.add(parent_hexsha8)
heapq.heappush(heap, (depth + 1, parent))
return None
def get_all_parent_hexsha8s(commit: git.Commit):
"""Helper function to recursively get hexsha8 values for all parents of a commit."""
unvisited = [commit]
visited = []
while unvisited:
current_commit = unvisited.pop(0)
visited.append(current_commit.hexsha[:build_len])
for parent in current_commit.parents:
if parent.hexsha[:build_len] not in visited:
unvisited.append(parent)
return visited
def get_commit_name(hexsha8: str):
"""Helper function to find a human-readable name for a commit if possible."""
if repo is None:
return hexsha8
for h in repo.heads:
if h.commit.hexsha[:build_len] == hexsha8:
return h.name
for t in repo.tags:
if t.commit.hexsha[:build_len] == hexsha8:
return t.name
return hexsha8
def get_commit_hexsha8(name: str):
"""Helper function to search for a commit given a human-readable name."""
if repo is None:
def _check_keys(self, keys: set) -> Optional[set]:
"""Private helper method that checks against required data keys and returns missing ones."""
if not keys >= self.check_keys:
return self.check_keys - keys
return None
for h in repo.heads:
if h.name == name:
return h.commit.hexsha[:build_len]
for t in repo.tags:
if t.name == name:
return t.commit.hexsha[:build_len]
for c in repo.iter_commits("--all"):
if c.hexsha[:build_len] == name[:build_len]:
return c.hexsha[:build_len]
return None
def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
"""Helper method to find the most recent parent measured in number of commits for which there is data."""
heap: list[tuple[int, git.Commit]] = [(0, commit)]
seen_hexsha8 = set()
while heap:
depth, current_commit = heapq.heappop(heap)
current_hexsha8 = commit.hexsha[:self.build_len]
if current_hexsha8 in self.builds:
return current_hexsha8
for parent in commit.parents:
parent_hexsha8 = parent.hexsha[:self.build_len]
if parent_hexsha8 not in seen_hexsha8:
seen_hexsha8.add(parent_hexsha8)
heapq.heappush(heap, (depth + 1, parent))
return None
def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
"""Helper method to recursively get hexsha8 values for all parents of a commit."""
unvisited = [commit]
visited = []
while unvisited:
current_commit = unvisited.pop(0)
visited.append(current_commit.hexsha[:self.build_len])
for parent in current_commit.parents:
if parent.hexsha[:self.build_len] not in visited:
unvisited.append(parent)
return visited
def get_commit_name(self, hexsha8: str) -> str:
"""Helper method to find a human-readable name for a commit if possible."""
if self.repo is None:
return hexsha8
for h in self.repo.heads:
if h.commit.hexsha[:self.build_len] == hexsha8:
return h.name
for t in self.repo.tags:
if t.commit.hexsha[:self.build_len] == hexsha8:
return t.name
return hexsha8
def get_commit_hexsha8(self, name: str) -> Optional[str]:
"""Helper method to search for a commit given a human-readable name."""
if self.repo is None:
return None
for h in self.repo.heads:
if h.name == name:
return h.commit.hexsha[:self.build_len]
for t in self.repo.tags:
if t.name == name:
return t.commit.hexsha[:self.build_len]
for c in self.repo.iter_commits("--all"):
if c.hexsha[:self.build_len] == name[:self.build_len]:
return c.hexsha[:self.build_len]
return None
def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
"""Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
return []
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
"""
Helper method that gets table rows for some list of properties.
Rows are created by combining those where all provided properties are equal.
The resulting rows are then grouped by the provided properties and the t/s values are averaged.
The returned rows are unique in terms of property combinations.
"""
return []
class LlamaBenchDataSQLite3(LlamaBenchData):
connection: sqlite3.Connection
cursor: sqlite3.Cursor
def __init__(self):
super().__init__()
self.connection = sqlite3.connect(":memory:")
self.cursor = self.connection.cursor()
self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});")
def _builds_init(self):
if self.connection:
self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
if self.build_len_min != self.build_len_max:
logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
"Try purging the the database of old commits.")
self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
super()._builds_init()
def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
data = self.cursor.execute(
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
return reversed(data) if reverse else data
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
select_string = ", ".join(
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
equal_string = " AND ".join(
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
)
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
return self.cursor.execute(query).fetchall()
class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
def __init__(self, data_file: str):
super().__init__()
self.connection.close()
self.connection = sqlite3.connect(data_file)
self.cursor = self.connection.cursor()
self._builds_init()
@staticmethod
def valid_format(data_file: str) -> bool:
connection = sqlite3.connect(data_file)
cursor = connection.cursor()
try:
if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
except sqlite3.DatabaseError as e:
logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
cursor = None
connection.close()
return True if cursor else False
class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
def __init__(self, data_file: str):
super().__init__()
with open(data_file, "r", encoding="utf-8") as fp:
for i, line in enumerate(fp):
parsed = json.loads(line)
for k in parsed.keys() - set(DB_FIELDS):
del parsed[k]
if (missing_keys := self._check_keys(parsed.keys())):
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
self._builds_init()
@staticmethod
def valid_format(data_file: str) -> bool:
try:
with open(data_file, "r", encoding="utf-8") as fp:
for line in fp:
json.loads(line)
break
except Exception as e:
logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
return False
return True
class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
def __init__(self, data_files: list[str]):
super().__init__()
for data_file in data_files:
with open(data_file, "r", encoding="utf-8") as fp:
parsed = json.load(fp)
for i, entry in enumerate(parsed):
for k in entry.keys() - set(DB_FIELDS):
del entry[k]
if (missing_keys := self._check_keys(entry.keys())):
raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
self._builds_init()
@staticmethod
def valid_format(data_files: list[str]) -> bool:
if not data_files:
return False
for data_file in data_files:
try:
with open(data_file, "r", encoding="utf-8") as fp:
json.load(fp)
except Exception as e:
logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
return False
return True
class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
def __init__(self, data_files: list[str]):
super().__init__()
for data_file in data_files:
with open(data_file, "r", encoding="utf-8") as fp:
for i, parsed in enumerate(csv.DictReader(fp)):
keys = set(parsed.keys())
for k in keys - set(DB_FIELDS):
del parsed[k]
if (missing_keys := self._check_keys(keys)):
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
self._builds_init()
@staticmethod
def valid_format(data_files: list[str]) -> bool:
if not data_files:
return False
for data_file in data_files:
try:
with open(data_file, "r", encoding="utf-8") as fp:
for parsed in csv.DictReader(fp):
break
except Exception as e:
logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
return False
return True
bench_data = None
if len(input_file) == 1:
if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
bench_data = LlamaBenchDataSQLite3File(input_file[0])
elif LlamaBenchDataJSON.valid_format(input_file):
bench_data = LlamaBenchDataJSON(input_file)
elif LlamaBenchDataJSONL.valid_format(input_file[0]):
bench_data = LlamaBenchDataJSONL(input_file[0])
elif LlamaBenchDataCSV.valid_format(input_file):
bench_data = LlamaBenchDataCSV(input_file)
else:
if LlamaBenchDataJSON.valid_format(input_file):
bench_data = LlamaBenchDataJSON(input_file)
elif LlamaBenchDataCSV.valid_format(input_file):
bench_data = LlamaBenchDataCSV(input_file)
if not bench_data:
raise RuntimeError("No valid (or some invalid) input files found.")
if not bench_data.builds:
raise RuntimeError(f"{input_file} does not contain any builds.")
hexsha8_baseline = name_baseline = None
# If the user specified a baseline, try to find a commit for it:
if known_args.baseline is not None:
if known_args.baseline in builds:
if known_args.baseline in bench_data.builds:
hexsha8_baseline = known_args.baseline
if hexsha8_baseline is None:
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
name_baseline = known_args.baseline
if hexsha8_baseline is None:
logger.error(f"cannot find data for baseline={known_args.baseline}.")
sys.exit(1)
# Otherwise, search for the most recent parent of master for which there is data:
elif repo is not None:
hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
elif bench_data.repo is not None:
hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)
if hexsha8_baseline is None:
logger.error("No baseline was provided and did not find data for any master branch commits.\n")
@@ -235,27 +464,25 @@ else:
sys.exit(1)
name_baseline = get_commit_name(hexsha8_baseline)
name_baseline = bench_data.get_commit_name(hexsha8_baseline)
hexsha8_compare = name_compare = None
# If the user has specified a compare value, try to find a corresponding commit:
if known_args.compare is not None:
if known_args.compare in builds:
if known_args.compare in bench_data.builds:
hexsha8_compare = known_args.compare
if hexsha8_compare is None:
hexsha8_compare = get_commit_hexsha8(known_args.compare)
hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
name_compare = known_args.compare
if hexsha8_compare is None:
logger.error(f"cannot find data for compare={known_args.compare}.")
sys.exit(1)
# Otherwise, search for the commit for llama-bench was most recently run
# and that is not a parent of master:
elif repo is not None:
hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit)
builds_timestamp = cursor.execute(
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
for (hexsha8, _) in reversed(builds_timestamp):
elif bench_data.repo is not None:
hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
if hexsha8 not in hexsha8s_master:
hexsha8_compare = hexsha8
break
@@ -270,26 +497,7 @@ else:
parser.print_help()
sys.exit(1)
name_compare = get_commit_name(hexsha8_compare)
def get_rows(properties):
"""
Helper function that gets table rows for some list of properties.
Rows are created by combining those where all provided properties are equal.
The resulting rows are then grouped by the provided properties and the t/s values are averaged.
The returned rows are unique in terms of property combinations.
"""
select_string = ", ".join(
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
equal_string = " AND ".join(
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
)
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
return cursor.execute(query).fetchall()
name_compare = bench_data.get_commit_name(hexsha8_compare)
# If the user provided columns to group the results by, use them:
@@ -297,16 +505,16 @@ if known_args.show is not None:
show = known_args.show.split(",")
unknown_cols = []
for prop in show:
if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen.
if prop not in KEY_PROPERTIES[:-3]: # Last three values are n_prompt, n_gen, n_depth.
unknown_cols.append(prop)
if unknown_cols:
logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
parser.print_usage()
sys.exit(1)
rows_show = get_rows(show)
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
# Otherwise, select those columns where the values are not all the same:
else:
rows_full = get_rows(KEY_PROPERTIES)
rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare)
properties_different = []
for i, kp_i in enumerate(KEY_PROPERTIES):
if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]:
@@ -318,7 +526,7 @@ else:
show = []
# Show CPU and/or GPU by default even if the hardware for all results is the same:
if "n_gpu_layers" not in properties_different:
if rows_full and "n_gpu_layers" not in properties_different:
ngl = int(rows_full[0][KEY_PROPERTIES.index("n_gpu_layers")])
if ngl != 99 and "cpu_info" not in properties_different:
@@ -336,7 +544,11 @@ else:
show.remove(prop)
except ValueError:
pass
rows_show = get_rows(show)
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
if not rows_show:
logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")
sys.exit(1)
table = []
for row in rows_show:
+1 -1
View File
@@ -1 +1 @@
b59bddafe278877dfa22a80e53a637513862babb
9b048bb72b811f50b0c30d9e5c84d6ff9f4bf005
+1
View File
@@ -23,6 +23,7 @@ add_library(llama
llama-memory.cpp
llama-mmap.cpp
llama-model-loader.cpp
llama-model-saver.cpp
llama-model.cpp
llama-quant.cpp
llama-sampling.cpp
+3
View File
@@ -1481,6 +1481,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
+264 -11
View File
@@ -93,6 +93,7 @@ llama_context::llama_context(
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.op_offload = params.op_offload;
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -243,7 +244,7 @@ llama_context::llama_context(
}
}
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
if (pipeline_parallel) {
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
@@ -358,7 +359,9 @@ llama_context::llama_context(
}
}
llama_context::~llama_context() = default;
llama_context::~llama_context() {
ggml_opt_free(opt_ctx);
}
void llama_context::synchronize() {
ggml_backend_sched_synchronize(sched.get());
@@ -1701,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
}
}
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io);
if (kv_self != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
kv_self->state_write(io);
}
return io.n_bytes();
}
@@ -1787,10 +1792,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
}
}
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
kv_self->state_read(io);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io);
}
return io.n_bytes();
}
@@ -1798,9 +1806,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io, seq_id);
kv_self->state_write(io, seq_id);
}
return io.n_bytes();
}
@@ -1808,9 +1818,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io, seq_id);
kv_self->state_read(io, seq_id);
}
return io.n_bytes();
}
@@ -1838,6 +1850,215 @@ void llama_context::perf_reset() {
t_p_eval_us = n_p_eval = 0;
}
//
// training
//
static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
if (!tensor || tensor->type != GGML_TYPE_F32) {
return;
}
if (!param_filter(tensor, userdata)) {
return;
}
if (strcmp(tensor->name, "token_embd.weight") == 0) {
return; // FIXME
}
if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
return; // FIXME
}
ggml_set_param(tensor);
}
void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
GGML_ASSERT(!opt_ctx);
model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
GGML_ASSERT(n_batch % n_ubatch == 0);
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
opt_params.opt_period = n_batch / n_ubatch;
opt_params.get_opt_pars = lopt_params.get_opt_pars;
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
opt_ctx = ggml_opt_init(opt_params);
llama_opt_param_filter param_filter = lopt_params.param_filter;
void * param_filter_ud = lopt_params.param_filter_ud;
//llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
llama_set_param(model->type_embd, param_filter, param_filter_ud);
llama_set_param(model->pos_embd, param_filter, param_filter_ud);
llama_set_param(model->tok_norm, param_filter, param_filter_ud);
llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
llama_set_param(model->output_norm, param_filter, param_filter_ud);
llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
llama_set_param(model->output, param_filter, param_filter_ud);
llama_set_param(model->output_b, param_filter, param_filter_ud);
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
llama_set_param(model->cls, param_filter, param_filter_ud);
llama_set_param(model->cls_b, param_filter, param_filter_ud);
llama_set_param(model->cls_out, param_filter, param_filter_ud);
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
for (struct llama_layer & layer : model->layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
}
}
}
void llama_context::opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start) {
GGML_ASSERT(opt_ctx);
const uint32_t n_ctx = llama_model_n_ctx_train(&model);
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->clear();
llama_kv_cache_guard kv_guard(kv_self);
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch;
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
batch.pos [pos_batch] = pos_ctx + pos_batch;
batch.n_seq_id[pos_batch] = 1;
batch.seq_id [pos_batch][0] = 0;
batch.logits [pos_batch] = true;
}
const auto n_tokens_all = batch.n_tokens;
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
int64_t n_outputs_all = n_tokens_all;
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
n_outputs = ubatch.n_tokens;
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
GGML_ABORT("TODO: handle this error");
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
struct ggml_context * ctx_compute_opt;
{
const size_t size_gf = ggml_graph_size(gf);
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ctx_compute_opt = ggml_init(params);
}
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
ggml_opt_alloc(opt_ctx, train);
res->set_inputs(&ubatch);
{
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
GGML_ASSERT(labels->ne[1] == n_ubatch);
ggml_set_zero(labels);
const float onef = 1.0f;
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
}
}
ggml_opt_eval(opt_ctx, result);
if (callback) {
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
}
ggml_free(ctx_compute_opt);
}
}
kv_guard.commit();
}
void llama_context::opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
const uint32_t n_ctx = this->n_ctx();
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
const int64_t ndata = ggml_opt_dataset_ndata(dataset);
GGML_ASSERT(idata_split >= 0);
GGML_ASSERT(idata_split <= ndata);
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
std::vector<llama_token> tokens(n_ctx);
std::vector<llama_token> labels_sparse(n_ctx);
int64_t idata = 0;
int64_t t_loop_start = ggml_time_us();
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
for (; idata < idata_split; ++idata) {
constexpr bool train = true;
const int64_t idata_in_loop = idata*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
t_loop_start = ggml_time_us();
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
for (; idata < ndata; ++idata) {
constexpr bool train = false;
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
llama_batch_free(batch);
}
//
// interface implementation
//
@@ -1871,6 +2092,7 @@ llama_context_params llama_context_default_params() {
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.no_perf =*/ true,
/*.op_offload =*/ true,
};
return result;
@@ -2455,3 +2677,34 @@ void llama_perf_context_print(const llama_context * ctx) {
void llama_perf_context_reset(llama_context * ctx) {
ctx->perf_reset();
}
//
// training
//
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
GGML_UNUSED(tensor);
GGML_UNUSED(userdata);
return true;
}
void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
ctx->opt_init(model, lopt_params);
}
void llama_opt_epoch(
struct llama_context * ctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
ctx->opt_epoch(
dataset,
result_train,
result_eval,
idata_split,
callback_train,
callback_eval);
}
+30
View File
@@ -7,6 +7,7 @@
#include "llama-adapter.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
#include <map>
#include <vector>
@@ -133,6 +134,32 @@ struct llama_context {
llama_perf_context_data perf_get_data() const;
void perf_reset();
//
// training
//
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
void opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start);
private:
//
// output
@@ -212,6 +239,9 @@ private:
ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
+1
View File
@@ -30,6 +30,7 @@ struct llama_cparams {
bool flash_attn;
bool no_perf;
bool warmup;
bool op_offload;
enum llama_pooling_type pooling_type;
+1
View File
@@ -971,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
+3
View File
@@ -298,6 +298,7 @@ class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
@@ -312,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -328,6 +330,7 @@ public:
}
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
+8
View File
@@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
void llama_kv_cache_unified::set_full() {
n = size;
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
// setting it to 0 is the simplest way to achieve that
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
head = 0;
}
llama_sbatch llama_kv_cache_unified::sbatch_init(
@@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
void llama_kv_cache_recurrent::set_full() {
n = size;
head = 0;
}
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
+4 -10
View File
@@ -171,11 +171,8 @@ public:
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
@@ -343,11 +340,8 @@ public:
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
+22 -17
View File
@@ -301,12 +301,12 @@ namespace GGUFMeta {
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT(
(std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
}
result.resize(arr_info.length);
@@ -330,12 +330,12 @@ namespace GGUFMeta {
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT(
(std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
}
if (arr_info.length > N_MAX) {
@@ -469,7 +469,7 @@ llama_model_loader::llama_model_loader(
meta.reset(gguf_init_from_file(fname.c_str(), params));
if (!meta) {
throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
}
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
@@ -528,7 +528,7 @@ llama_model_loader::llama_model_loader(
};
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
if (!ctx_gguf) {
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
}
// check idx
@@ -822,13 +822,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
mappings.reserve(files.size());
mmaps_used.reserve(files.size());
for (const auto & file : files) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
if (!reg) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
bool is_numa = false;
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (dev) {
auto * reg = ggml_backend_dev_backend_reg(dev);
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
if (is_numa_fn) {
is_numa = is_numa_fn();
}
}
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
mmaps_used.emplace_back(mapping->size(), 0);
if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
+281
View File
@@ -0,0 +1,281 @@
#include "llama-model-saver.h"
#include "gguf.h"
#include "llama.h"
#include "llama-hparams.h"
#include "llama-model.h"
#include "llama-vocab.h"
#include <string>
llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
gguf_ctx = gguf_init_empty();
}
llama_model_saver::~llama_model_saver() {
gguf_free(gguf_ctx);
}
void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
}
[[noreturn]]
void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
GGML_UNUSED(key);
GGML_UNUSED(value);
GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
}
template <typename Container>
void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
GGML_ASSERT(n_values <= value.size());
if (n_values == 0) {
return;
}
if (per_layer) {
bool all_values_the_same = true;
for (size_t i = 1; i < n_values; ++i) {
if (value[i] != value[0]) {
all_values_the_same = false;
break;
}
}
if (all_values_the_same) {
add_kv(key, value[0]);
return;
}
}
if (std::is_same<typename Container::value_type, uint8_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, int8_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, uint32_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, int32_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, float>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
} else if (std::is_same<Container, std::string>::value) {
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
} else {
GGML_ABORT("fatal error");
}
}
void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
std::vector<const char *> tmp(value.size());
for (size_t i = 0; i < value.size(); ++i) {
tmp[i] = value[i].c_str();
}
gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
}
void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
if (!tensor) {
return;
}
if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
return;
}
gguf_add_tensor(gguf_ctx, tensor);
}
void llama_model_saver::add_kv_from_model() {
const llama_hparams & hparams = model.hparams;
const llama_vocab & vocab = model.vocab;
const int32_t n_vocab = vocab.n_tokens();
std::vector<std::string> tokens(n_vocab);
std::vector<float> scores(n_vocab);
std::vector<int32_t> token_types(n_vocab);
for (int32_t id = 0; id < n_vocab; ++id) {
const llama_vocab::token_data & token_data = vocab.get_token_data(id);
tokens[id] = token_data.text;
scores[id] = token_data.score;
switch(token_data.attr) {
case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break;
case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break;
case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break;
case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break;
case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break;
case LLAMA_TOKEN_ATTR_UNDEFINED:
default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break;
}
}
// add_kv(LLM_KV_GENERAL_TYPE, ???);
add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name());
// add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???);
// add_kv(LLM_KV_GENERAL_ALIGNMENT, ???);
add_kv(LLM_KV_GENERAL_NAME, model.name);
// add_kv(LLM_KV_GENERAL_AUTHOR, ???);
// add_kv(LLM_KV_GENERAL_VERSION, ???);
// add_kv(LLM_KV_GENERAL_URL, ???);
// add_kv(LLM_KV_GENERAL_DESCRIPTION, ???);
// add_kv(LLM_KV_GENERAL_LICENSE, ???);
// add_kv(LLM_KV_GENERAL_SOURCE_URL, ???);
// add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???);
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
// add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???);
add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert);
add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type));
add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id);
add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping);
add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping);
add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm);
add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers);
add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true);
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor);
add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor);
add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn);
add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned);
add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
// TODO: implement split file support
// add_kv(LLM_KV_SPLIT_NO, ???);
// add_kv(LLM_KV_SPLIT_COUNT, ???);
// add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???);
add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms);
add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model());
add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre());
add_kv(LLM_KV_TOKENIZER_LIST, tokens);
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types);
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types());
add_kv(LLM_KV_TOKENIZER_SCORES, scores);
add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges());
// FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos()));
add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos()));
add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot()));
add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom()));
add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk()));
add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep()));
add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad()));
// add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated
// add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
// add_kv(LLM_KV_TOKENIZER_HF_JSON, ???);
// add_kv(LLM_KV_TOKENIZER_RWKV, ???);
add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre()));
add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf()));
add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid()));
add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad()));
add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep()));
add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep()));
// TODO: implement LoRA support
// add_kv(LLM_KV_ADAPTER_TYPE, ???);
// add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???);
// deprecated
// add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???);
// add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???);
// add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???);
}
void llama_model_saver::add_tensors_from_model() {
if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
}
add_tensor(model.type_embd);
add_tensor(model.pos_embd);
add_tensor(model.tok_norm);
add_tensor(model.tok_norm_b);
add_tensor(model.output_norm);
add_tensor(model.output_norm_b);
add_tensor(model.output);
add_tensor(model.output_b);
add_tensor(model.output_norm_enc);
add_tensor(model.cls);
add_tensor(model.cls_b);
add_tensor(model.cls_out);
add_tensor(model.cls_out_b);
for (const struct llama_layer & layer : model.layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
}
}
}
void llama_model_saver::save(const std::string & path_model) {
gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
}
+37
View File
@@ -0,0 +1,37 @@
#pragma once
#include "llama.h"
#include "llama-arch.h"
#include <vector>
struct llama_model_saver {
struct gguf_context * gguf_ctx = nullptr;
const struct llama_model & model;
const struct LLM_KV llm_kv;
llama_model_saver(const struct llama_model & model);
~llama_model_saver();
void add_kv(enum llm_kv key, uint32_t value);
void add_kv(enum llm_kv key, int32_t value);
void add_kv(enum llm_kv key, float value);
void add_kv(enum llm_kv key, bool value);
void add_kv(enum llm_kv key, const char * value);
[[noreturn]]
void add_kv(enum llm_kv key, char value); // needed to make the template below compile
template <typename Container>
void add_kv(enum llm_kv key, const Container & value, bool per_layer = false);
void add_kv(enum llm_kv key, const std::vector<std::string> & value);
void add_tensor(const struct ggml_tensor * tensor);
void add_kv_from_model();
void add_tensors_from_model();
void save(const std::string & path_model);
};
+213 -35
View File
@@ -117,6 +117,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
};
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
}
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
if (kv.second == name) {
@@ -1385,6 +1389,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// Add additional layer/vocab/etc checks here for other model sizes
default: type = LLM_TYPE_UNKNOWN;
}
// For Granite MoE Shared
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
} break;
case LLM_ARCH_CHAMELEON:
{
@@ -1768,6 +1775,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
// For Granite MoE Shared
if (hparams.n_ff_shexp > 0) {
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
}
}
}
} break;
@@ -4264,7 +4278,7 @@ uint64_t llama_model::n_elements() const {
}
void llama_model::print_info() const {
const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
bool is_var = false;
@@ -4325,7 +4339,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
@@ -4381,10 +4395,13 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
}
if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
if (arch == LLM_ARCH_MINICPM ||
arch == LLM_ARCH_GRANITE ||
arch == LLM_ARCH_GRANITE_MOE) {
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
if (arch == LLM_ARCH_BAILINGMOE) {
@@ -4594,11 +4611,6 @@ struct llm_build_llama : public llm_graph_context {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -4670,11 +4682,6 @@ struct llm_build_llama : public llm_graph_context {
cb(cur, "ffn_moe_out", il);
}
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@@ -4697,11 +4704,6 @@ struct llm_build_llama : public llm_graph_context {
// lm_head
cur = build_lora_mm(model.output, cur);
// For Granite architecture
if (hparams.f_logit_scale) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
@@ -4812,11 +4814,6 @@ struct llm_build_deci : public llm_graph_context {
continue;
}
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
// modified to support attention-free layer of Llama-3_1-Nemotron-51B
ggml_tensor * ffn_inp = cur;
if (n_head > 0) {
@@ -4840,11 +4837,6 @@ struct llm_build_deci : public llm_graph_context {
cb(cur, "ffn_out", il);
}
// For Granite architecture
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@@ -4867,11 +4859,6 @@ struct llm_build_deci : public llm_graph_context {
// lm_head
cur = build_lora_mm(model.output, cur);
// For Granite architecture
if (hparams.f_logit_scale) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
@@ -12210,6 +12197,194 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
}
};
struct llm_build_granite : public llm_graph_context {
llm_build_granite(
const llama_model & model,
const llm_graph_params & params,
ggml_cgraph * gf,
const bool use_rope = true)
: llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - built only if rope enabled
ggml_tensor * inp_pos = nullptr;
if (use_rope) {
inp_pos = build_inp_pos();
}
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (use_rope) {
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// For Granite architectures - scale residual
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
// For Granite MoE Shared
if (hparams.n_ff_shexp > 0) {
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
} else {
cur = moe_out;
}
}
// For Granite architectures - scale residual
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
// For Granite architectures - scale logits
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
// ref: https://github.com/facebookresearch/chameleon
// based on the original build_llama() function, changes:
// * qk-norm
@@ -12917,8 +13092,6 @@ llm_graph_result_ptr llama_model::build_graph(
case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
llm = std::make_unique<llm_build_llama>(*this, params, gf);
} break;
@@ -13149,6 +13322,11 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
} break;
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
llm = std::make_unique<llm_build_granite>(*this, params, gf);
} break;
case LLM_ARCH_CHAMELEON:
{
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
+2
View File
@@ -96,6 +96,8 @@ enum llm_type {
LLM_TYPE_235B_A22B,
};
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
struct llama_layer_posnet {
// resnet
struct ggml_tensor * norm1 = nullptr;
+15 -13
View File
@@ -14,6 +14,12 @@
#include <thread>
#include <unordered_map>
// Quantization types. Changes to this struct must be replicated in quantize.cpp
struct tensor_quantization {
std::string name;
ggml_type quant = GGML_TYPE_COUNT;
};
static void zeros(std::ofstream & file, size_t n) {
char zero = 0;
for (size_t i = 0; i < n; ++i) {
@@ -48,12 +54,6 @@ struct quantize_state_impl {
{}
};
// changes to this struct must be replicated in quantize.cpp
struct tensor_quantization {
std::string name;
ggml_type quant = GGML_TYPE_COUNT;
};
static void llama_tensor_dequantize_impl(
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
const size_t nelements, const int nthread
@@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
nthread = std::thread::hardware_concurrency();
}
// mmap consistently increases speed Linux, and also increases speed on Windows with
// mmap consistently increases speed on Linux, and also increases speed on Windows with
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
#if defined(__linux__) || defined(_WIN32)
constexpr bool use_mmap = true;
@@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
llama_model_kv_override * kv_overrides = nullptr;
if (params->kv_overrides) {
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}
@@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// unless the user specifies a type
if (params->tensor_types) {
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
const std::string tensor_name(tensor->name);
for (const auto & [tname, qtype] : tensor_types) {
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
if (qtype != new_type) {
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
if (qtype != new_type) {
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
new_type = qtype;
break; // if two or more types are specified for the tensor, first match wins
}
new_type = qtype;
break;
}
}
}
}
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
new_type = params->token_embedding_type;
}
+31 -4
View File
@@ -1,5 +1,7 @@
#include "llama-vocab.h"
#include "ggml.h"
#include "gguf.h"
#include "llama-impl.h"
#include "llama-model-loader.h"
@@ -1234,6 +1236,9 @@ struct fragment_buffer_variant {
struct llama_vocab::impl {
uint32_t n_token_types = 0; // for BERT-style token types
std::string tokenizer_model;
std::string tokenizer_pre;
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
@@ -1369,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// determine vocab type
{
std::string tokenizer_model;
std::string tokenizer_pre;
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
@@ -1466,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN
@@ -2789,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
pimpl->load(ml, kv);
}
std::string llama_vocab::get_tokenizer_model() const {
return pimpl->tokenizer_model;
}
std::string llama_vocab::get_tokenizer_pre() const {
return pimpl->tokenizer_pre;
}
enum llama_vocab_type llama_vocab::get_type() const {
return pimpl->type;
}
@@ -3011,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
return it->second;
}
std::vector<std::string> llama_vocab::get_bpe_merges() const {
std::vector<std::string> result(pimpl->bpe_ranks.size());
for (const auto & pair : pimpl->bpe_ranks) {
result[pair.second] = pair.first.first + " " + pair.first.second;
}
return result;
}
std::vector<char> llama_vocab::get_precompiled_charsmap() const {
return pimpl->precompiled_charsmap;
}
int32_t llama_vocab::tokenize(
const char * text,
int32_t text_len,
+6
View File
@@ -21,6 +21,9 @@ struct llama_vocab {
void load(llama_model_loader & ml, const LLM_KV & kv);
std::string get_tokenizer_model() const;
std::string get_tokenizer_pre() const;
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;
@@ -80,6 +83,9 @@ struct llama_vocab {
int max_token_len() const;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
std::vector<std::string> get_bpe_merges() const;
std::vector<char> get_precompiled_charsmap() const;
int32_t tokenize(
const char * text,
+14
View File
@@ -4,6 +4,7 @@
#include "llama-mmap.h"
#include "llama-vocab.h"
#include "llama-model-loader.h"
#include "llama-model-saver.h"
#include "llama-model.h"
#include "ggml.h"
@@ -139,6 +140,11 @@ static struct llama_model * llama_model_load_from_file_impl(
struct llama_model_params params) {
ggml_time_init();
if (!params.vocab_only && ggml_backend_reg_count() == 0) {
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
return nullptr;
}
unsigned cur_percentage = 0;
if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage;
@@ -253,6 +259,13 @@ struct llama_model * llama_model_load_from_splits(
return llama_model_load_from_file_impl(splits.front(), splits, params);
}
void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
llama_model_saver ms(*model);
ms.add_kv_from_model();
ms.add_tensors_from_model();
ms.save(path_model);
}
//
// chat templates
//
@@ -338,3 +351,4 @@ const char * llama_print_system_info(void) {
return s.c_str();
}
+1
View File
@@ -144,6 +144,7 @@ endif()
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-regex-partial.cpp)
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
if (NOT WIN32)
+50 -43
View File
@@ -823,7 +823,7 @@ struct test_case {
ggml_build_forward_expand(gf, out);
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
ggml_build_backward_expand(ctx.get(), gb, nullptr);
if (expect.size() != 1 || expect[0] != 0.0f) {
GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
@@ -1026,7 +1026,7 @@ struct test_example : public test_case {
// Step 3: return the output tensor.
return out;
}
// In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
// In order to also check the gradients for your op, add calls like ggml_set_param(a)
// immediately after you create the tensors.
// This is optional and only makes sense if a backward pass has actually been implemented for the new op.
};
@@ -1058,7 +1058,7 @@ struct test_unary : public test_case {
auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
if (grad_supported) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
@@ -1067,7 +1067,7 @@ struct test_unary : public test_case {
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
if (grad_supported) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
}
@@ -1133,7 +1133,7 @@ struct test_get_rows : public test_case {
const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
if (grad_supported) {
ggml_set_param(ctx, in);
ggml_set_param(in);
// rows is a constant input -> no gradients
}
@@ -1322,7 +1322,7 @@ struct test_repeat : public test_case {
ggml_set_name(target, "target");
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
ggml_tensor * out = ggml_repeat(ctx, src, target);
@@ -1406,7 +1406,7 @@ struct test_dup : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
if (_use_permute) {
@@ -1442,7 +1442,7 @@ struct test_set : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
auto ne_dst = ne;
@@ -1450,7 +1450,7 @@ struct test_set : public test_case {
ne_dst[i] *= 2;
}
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
ggml_set_param(ctx, dst);
ggml_set_param(dst);
ggml_set_name(dst, "dst");
size_t offset = 0;
@@ -1498,7 +1498,7 @@ struct test_cpy : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
if (_src_use_permute) {
@@ -1536,7 +1536,7 @@ struct test_cont : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
src = ggml_transpose(ctx, src);
@@ -1583,8 +1583,8 @@ struct test_bin_bcast : public test_case {
// The backward pass supports broadcasting only for GGML_ADD:
const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
if (grad_supported) {
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_param(a);
ggml_set_param(b);
}
ggml_tensor * out = op(ctx, a, b);
@@ -1632,11 +1632,11 @@ struct test_add1 : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
// ggml_set_param(ctx, b); // TODO: implement
// ggml_set_param(b); // TODO: implement
ggml_set_name(b, "b");
ggml_tensor * out = ggml_add1(ctx, a, b);
@@ -1667,7 +1667,7 @@ struct test_scale : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_scale(ctx, a, scale);
@@ -1762,7 +1762,7 @@ struct test_rms_norm : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
if (v) {
@@ -2028,9 +2028,9 @@ struct test_mul_mat : public test_case {
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_param(ctx, b);
ggml_set_param(b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@@ -2040,22 +2040,29 @@ struct test_mul_mat : public test_case {
ggml_set_name(a, "a_permuted");
ggml_set_name(b, "b_permuted");
} else {
if (v) {
a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(a);
}
ggml_set_param(b);
}
a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
}
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(a);
}
ggml_set_param(b);
}
ggml_set_param(ctx, b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@@ -2204,7 +2211,7 @@ struct test_sqr : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sqr(ctx, a);
@@ -2233,7 +2240,7 @@ struct test_sqrt : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sqrt(ctx, a);
@@ -2273,7 +2280,7 @@ struct test_log : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_log(ctx, a);
@@ -2309,7 +2316,7 @@ struct test_sin : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sin(ctx, a);
@@ -2352,7 +2359,7 @@ struct test_cos : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_cos(ctx, a);
@@ -2432,7 +2439,7 @@ struct test_diag_mask_inf : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
@@ -2471,7 +2478,7 @@ struct test_soft_max : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * mask = nullptr;
@@ -2553,7 +2560,7 @@ struct test_rope : public test_case {
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
if (forward) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
@@ -2562,7 +2569,7 @@ struct test_rope : public test_case {
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
if (forward) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
}
@@ -2676,7 +2683,7 @@ struct test_pool2d : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_set_param(ctx, input);
ggml_set_param(input);
ggml_set_name(input, "input");
ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
@@ -2752,7 +2759,7 @@ struct test_im2col : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_set_param(ctx, input);
ggml_set_param(input);
ggml_set_name(input, "input");
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
@@ -2929,7 +2936,7 @@ struct test_sum : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sum(ctx, a);
@@ -2958,7 +2965,7 @@ struct test_sum_rows : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sum_rows(ctx, a);
@@ -2983,7 +2990,7 @@ struct test_mean : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_mean(ctx, a);
@@ -3129,11 +3136,11 @@ struct test_acc : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
ggml_set_param(ctx, b);
ggml_set_param(b);
ggml_set_name(b, "b");
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
@@ -3370,7 +3377,7 @@ struct test_cross_entropy_loss : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, logits);
ggml_set_param(logits);
ggml_set_name(logits, "logits");
ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -3452,7 +3459,7 @@ struct test_opt_step_adamw : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
ggml_set_param(a); // Despite tensor a having gradients the output tensor will not.
ggml_set_name(a, "a");
ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+3 -1
View File
@@ -832,7 +832,9 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+33 -21
View File
@@ -57,7 +57,8 @@ static helper_ctx_data helper_get_ctx_data(
enum ggml_opt_loss_type loss_type = GGML_OPT_LOSS_TYPE_SUM) {
std::vector<ggml_opt_dataset_t> datasets(ndata);
for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard);
float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset));
float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@@ -74,7 +75,8 @@ static helper_ctx_data helper_get_ctx_data(
datasets[ndata_shard-1] = dataset;
}
ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1);
ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(
GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1);
float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
@@ -113,7 +115,7 @@ static helper_ctx_data helper_get_ctx_data(
struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
ggml_set_name(weights, "weights");
ggml_set_param(ctx_static, weights);
ggml_set_param(weights);
struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
@@ -127,8 +129,11 @@ static helper_ctx_data helper_get_ctx_data(
GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
const int32_t opt_period = nbatch_logical / nbatch_physical;
struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
opt_params.opt_period = opt_period;
struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type);
opt_params.ctx_compute = ctx_compute;
opt_params.inputs = inputs;
opt_params.outputs = outputs;
opt_params.opt_period = opt_period;
if (!optimizer_defaults) {
opt_params.get_opt_pars = helper_get_test_opt_pars;
}
@@ -264,8 +269,9 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
}
@@ -334,8 +340,9 @@ static std::pair<int, int> test_forward_backward(
} else {
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
}
@@ -367,7 +374,8 @@ static std::pair<int, int> test_forward_backward(
float w0;
ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
for (int i = 0; i < 10; ++i) {
ggml_opt_forward_backward(cd.opt_ctx, nullptr);
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_opt_eval(cd.opt_ctx, cd.result);
}
ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
@@ -387,8 +395,9 @@ static std::pair<int, int> test_forward_backward(
} else {
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
}
@@ -492,14 +501,16 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
int idata = 0;
for (; idata < idata_split; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
for (; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward(cd.opt_ctx, cd.result2);
ggml_opt_eval(cd.opt_ctx, cd.result2);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
}
@@ -573,7 +584,6 @@ static std::pair<int, int> test_gradient_accumulation(
struct helper_ctx_data cd = helper_get_ctx_data(
backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
std::vector<float> grad_history(ndata);
for (int64_t idata = 0; idata < ndata; ++idata) {
@@ -584,15 +594,17 @@ static std::pair<int, int> test_gradient_accumulation(
if (nbatch_physical == 1) {
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
}
} else if (nbatch_physical == 2) {
for (int idata = 0; idata < ndata; idata += 2) {
const float idataf[2] = {float(idata + 0), float(idata + 1)};
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
grad_history[idata + 0] = 0.0f;
ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
@@ -617,7 +629,7 @@ static std::pair<int, int> test_gradient_accumulation(
}
subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol);
} else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
if (nbatch_physical == 1) {
subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
@@ -630,7 +642,7 @@ static std::pair<int, int> test_gradient_accumulation(
}
subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol);
} else {
GGML_ASSERT(false);
}
@@ -692,7 +704,8 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
std::mt19937 gen(12345);
std::normal_distribution<float> nd{0.0f, 0.1f};
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression);
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression);
float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset));
float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@@ -733,15 +746,14 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
ggml_set_name(a, "a");
ggml_set_param(ctx_static, a);
ggml_set_param(a);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
ggml_set_name(b, "b");
ggml_set_param(ctx_static, b);
ggml_set_param(b);
struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
ggml_set_name(f, "f");
ggml_set_param(ctx_static, f);
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
const float a0 = 1.0f;
@@ -853,7 +865,7 @@ int main(void) {
backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false);
backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
printf(" Device description: %s\n", ggml_backend_dev_description(devs[i]));
+288
View File
@@ -0,0 +1,288 @@
// Tests common_regex (esp. its partial final matches support).
#include "common.h"
#include "regex-partial.h"
#include <sstream>
#include <iostream>
#include <optional>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << " Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
struct test_case {
std::string pattern;
struct input_output {
std::string input;
common_regex_match output;
};
std::vector<input_output> inputs_outputs;
};
static std::string common_regex_match_type_name(common_regex_match_type type) {
switch (type) {
case COMMON_REGEX_MATCH_TYPE_NONE:
return "COMMON_REGEX_MATCH_TYPE_NONE";
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
case COMMON_REGEX_MATCH_TYPE_FULL:
return "COMMON_REGEX_MATCH_TYPE_FULL";
}
return "?";
}
static void test_regex() {
printf("[%s]\n", __func__);
auto test = [](const test_case & test_case) {
common_regex cr(test_case.pattern);
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
for (const auto & input_output : test_case.inputs_outputs) {
std::cout << " Input: " << input_output.input << '\n';
auto m = cr.search(input_output.input, 0);
if (m != input_output.output) {
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
std::ostringstream ss;
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
ss << "<no match>";
} else {
GGML_ASSERT(!input_output.output.groups.empty());
std::vector<std::string> parts;
for (const auto & g : m->groups) {
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
}
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
}
return ss.str();
};
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
std::cout << " Got: " << match_to_str(m) << '\n';
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
throw std::runtime_error("Test failed");
}
}
};
test({
"a",
{
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
}
});
test({
"abcd",
{
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"d", {}},
{"bcd", {}},
{"cde", {}},
{"cd", {}},
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
{"abbie", {}},
{"", {}},
}
});
test({
".*?ab",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
}
});
test({
"a.*?b",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"d", {}},
{"b", {}},
}
});
test({
"ab(?:cd){2,4}ef",
{
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abcde", {}},
{"abcdef", {}},
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
{"abcdcdcdcdcdef", {}},
{"abcde", {}},
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
}
});
test({
"a(?:rte| pure )fact",
{
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"fact", {}},
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
{"" , {}},
{"pure", {}},
{"pure fact", {}},
}
});
test({
"abc",
{
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"b", {}},
{"c", {}},
{"", {}},
}
});
test({
"(?:abc)?\\s*def",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
}
});
test({
"a+b",
{
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
}
});
test({
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
{
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
}
});
}
static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);
assert_equals<std::string>(
"((?:(?:c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abc"));
assert_equals<std::string>(
"(a+)[\\s\\S]*",
regex_to_reversed_partial_regex("a+"));
assert_equals<std::string>(
"(a*)[\\s\\S]*",
regex_to_reversed_partial_regex("a*"));
assert_equals<std::string>(
"(a?)[\\s\\S]*",
regex_to_reversed_partial_regex("a?"));
assert_equals<std::string>(
"([a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]"));
assert_equals<std::string>(
"((?:\\w+)?[a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals<std::string>(
"((?:a|b))[\\s\\S]*",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>(
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>(
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>(
"((?:(?:b)?a)?.*)[\\s\\S]*",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>(
"((?:(?:b)?.*)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>(
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>(
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>(
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("ab{2,4}c"));
}
int main() {
test_regex_to_reversed_partial_regex();
test_regex();
std::cout << "All tests passed.\n";
}
+2 -2
View File
@@ -123,8 +123,8 @@ int main(int argc, char ** argv) {
common_batch_clear(batch);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
for (int i = 0; i < pp; ++i) {
common_batch_add(batch, 0, i, { j }, false);
}
}
+20 -11
View File
@@ -20,10 +20,20 @@ Performance testing tool for llama.cpp.
## Syntax
```
usage: ./llama-bench [options]
usage: llama-bench [options]
options:
-h, --help
--numa <distribute|isolate|numactl> numa mode (default: disabled)
-r, --repetitions <n> number of times to repeat each test (default: 5)
--prio <0|1|2|3> process/thread priority (default: 0)
--delay <0...N> (seconds) delay between each test (default: 0)
-o, --output <csv|json|jsonl|md|sql> output format printed to stdout (default: md)
-oe, --output-err <csv|json|jsonl|md|sql> output format printed to stderr (default: none)
-v, --verbose verbose output
--progress print test progress indicators
test parameters:
-m, --model <filename> (default: models/7B/ggml-model-q4_0.gguf)
-p, --n-prompt <n> (default: 512)
-n, --n-gen <n> (default: 128)
@@ -33,28 +43,27 @@ options:
-ub, --ubatch-size <n> (default: 512)
-ctk, --cache-type-k <t> (default: f16)
-ctv, --cache-type-v <t> (default: f16)
-t, --threads <n> (default: 8)
-dt, --defrag-thold <f> (default: -1)
-t, --threads <n> (default: system dependent)
-C, --cpu-mask <hex,hex> (default: 0x0)
--cpu-strict <0|1> (default: 0)
--poll <0...100> (default: 50)
-ngl, --n-gpu-layers <n> (default: 99)
-rpc, --rpc <rpc_servers> (default: )
-rpc, --rpc <rpc_servers> (default: none)
-sm, --split-mode <none|layer|row> (default: layer)
-mg, --main-gpu <i> (default: 0)
-nkvo, --no-kv-offload <0|1> (default: 0)
-fa, --flash-attn <0|1> (default: 0)
-mmp, --mmap <0|1> (default: 1)
--numa <distribute|isolate|numactl> (default: disabled)
-embd, --embeddings <0|1> (default: 0)
-ts, --tensor-split <ts0/ts1/..> (default: 0)
-r, --repetitions <n> (default: 5)
--prio <0|1|2|3> (default: 0)
--delay <0...N> (seconds) (default: 0)
-o, --output <csv|json|jsonl|md|sql> (default: md)
-oe, --output-err <csv|json|jsonl|md|sql> (default: none)
-v, --verbose (default: 0)
-ot --override-tensors <tensor name pattern>=<buffer type>;...
(default: disabled)
-nopo, --no-op-offload <0|1> (default: 0)
Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.
Multiple values can be given for each parameter by separating them with ','
or by specifying the parameter multiple times. Ranges can be given as
'first-last' or 'first-last+step' or 'first-last*mult'.
```
llama-bench can perform three types of tests:
File diff suppressed because it is too large Load Diff
-35
View File
@@ -1,29 +1,3 @@
# llava (legacy)
add_library(llava OBJECT
llava.cpp
llava.h
clip.cpp
clip.h
)
target_link_libraries(llava PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(llava PUBLIC .)
target_include_directories(llava PUBLIC ../..)
target_include_directories(llava PUBLIC ../../common)
target_compile_features(llava PRIVATE cxx_std_17)
add_library(llava_static STATIC $<TARGET_OBJECTS:llava>)
if (BUILD_SHARED_LIBS)
set_target_properties(llava PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(llava PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(llava_shared SHARED $<TARGET_OBJECTS:llava>)
target_link_libraries(llava_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS llava_shared LIBRARY)
endif()
# mtmd
add_library(mtmd OBJECT
@@ -53,12 +27,10 @@ if (BUILD_SHARED_LIBS)
endif()
if (NOT MSVC)
target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
endif()
if(TARGET BUILD_INFO)
add_dependencies(llava BUILD_INFO)
add_dependencies(mtmd BUILD_INFO)
endif()
@@ -73,10 +45,3 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
set(TARGET llama-llava-clip-quantize-cli)
add_executable(${TARGET} clip-quantize-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
-44
View File
@@ -1,44 +0,0 @@
# Quantizing CLIP Visual Projector
This is the tool for quantizing the CLIP visual projector model. Quantization reduces the precision of the model's weights, which can significantly decrease the model size and improve inference speed, often with minimal impact on performance.
## Usage
To quantize a CLIP visual projector model, use the following command:
```sh
./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf <type>
```
After the quantization, the visual projector can be used freely with the existing LLAVA cli (LLAVA, Qwen2VL, etc).
### Arguments
- `/path/to/ggml-model-f32.gguf`: The path to the input model file in FP32 or FP16 format.
- `/path/to/ggml-model-quantized.gguf`: The path where the quantized model will be saved.
- `<type>`: The quantization type to apply. This should be an integer corresponding to one of the quantization types defined in the `enum ggml_type`.
### Quantization Types
The following quantization types are supported, based on the `enum ggml_type` definition:
- `2` - `q4_0`: 4-bit quantization with a single scale value.
- `3` - `q4_1`: 4-bit quantization with a separate scale value for each block.
- `6` - `q5_0`: 5-bit quantization with a single scale value.
- `7` - `q5_1`: 5-bit quantization with a separate scale value for each block.
- `8` - `q8_0`: 8-bit quantization with a single scale value.
### Example
To quantize a model using the `q4_0` quantization type, you would run:
```sh
./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf 2
```
This command will generate a quantized model at `/path/to/ggml-model-quantized.gguf` using the `q4_0` quantization method.
## Notes
- Quantization can lead to a loss in model accuracy, depending on the chosen quantization type. It is recommended to evaluate the quantized model's performance on your specific task to ensure it meets your requirements.
- The quantized model will typically be smaller in size and faster to run, making it more suitable for deployment in resource-constrained environments.
+4 -3
View File
@@ -41,8 +41,8 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advanta
Multimodal projector (`mmproj`) files are specific to each model architecture.
For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` flag to get the `mmproj` file:
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) ; See the guide [here](../../docs/multimodal/gemma3.md) - Note: 1B variant does not have vision support
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
@@ -52,6 +52,8 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
For older models, please refer to the relevant guide for instructions on how to obtain or create them:
NOTE: conversion scripts are located under `tools/mtmd/legacy-models`
- [LLaVA](../../docs/multimodal/llava.md)
- [MobileVLM](../../docs/multimodal/MobileVLM.md)
- [GLM-Edge](../../docs/multimodal/glmedge.md)
@@ -59,4 +61,3 @@ For older models, please refer to the relevant guide for instructions on how to
- [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
- [Google Gemma 3](../../docs/multimodal/gemma3.md)
-53
View File
@@ -1,53 +0,0 @@
#!/bin/bash
model_dir="/Users/cxt/model/llm/mobileVLM/MobileVLM-1.7B_processed"
projector_name="mmproj-model-f16.gguf"
llama_name="ggml-model-q4_k.gguf"
img_dir="/Users/cxt/model/llm"
img_name="demo.jpg"
prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:"
# img_name="cat.jpeg"
# prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:"
program_dir="build_64/bin"
binName="llama-mtmd-cli"
n_threads=4
deviceDir="/data/local/tmp"
saveDir="output"
if [ ! -d ${saveDir} ]; then
mkdir ${saveDir}
fi
function android_run() {
# # copy resource into device
# adb push ${model_dir}/${projector_name} ${deviceDir}/${projector_name}
# adb push ${model_dir}/${llama_name} ${deviceDir}/${llama_name}
adb push ${img_dir}/${img_name} ${deviceDir}/${img_name}
# copy program into device
adb push ${program_dir}/${binName} ${deviceDir}/${binName}
adb shell "chmod 0777 ${deviceDir}/${binName}"
# run
adb shell "echo cd ${deviceDir} ${deviceDir}/${binName} \
-m ${deviceDir}/${llama_name} \
--mmproj ${deviceDir}/${projector_name} \
-t ${n_threads} \
--image ${deviceDir}/${img_name} \
-p \"${prompt}\" \
> ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt"
adb shell "cd ${deviceDir}; pwd; ${deviceDir}/${binName} \
-m ${deviceDir}/${llama_name} \
--mmproj ${deviceDir}/${projector_name} \
-t ${n_threads} \
--image ${deviceDir}/${img_name} \
-p \"${prompt}\" \
>> ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt 2>&1"
adb pull ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt ${saveDir}
}
android_run
echo "android_run is Done!"
-8
View File
@@ -1,8 +0,0 @@
#!/bin/bash
cmake ../../../../ \
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
-DCMAKE_BUILD_TYPE=Release \
-DANDROID_ABI="arm64-v8a" \
-DANDROID_PLATFORM=android-23 $1
make -j4

Some files were not shown because too many files have changed in this diff Show More