Compare commits

..

145 Commits

Author SHA1 Message Date
Sigbjørn Skjæret 4b436e4e5e flake8 fix 2026-02-23 11:48:01 +01:00
Georgi Gerganov a6d3e9a239 ggml : relax asseerts for ggml_get_type_traits() 2026-02-22 21:37:58 +02:00
Georgi Gerganov 9c5d8dec37 gguf : add file size check for arrays 2026-02-22 21:36:56 +02:00
Georgi Gerganov c76408dbb9 gguf : add mem_size overflow test 2026-02-22 18:40:06 +02:00
Georgi Gerganov c79698f28a ggml : relax ggml_type asserts to debug-only 2026-02-22 16:32:39 +02:00
Georgi Gerganov 45250db0f8 ggml : remove deprecated ggml_type_sizef() 2026-02-22 16:23:57 +02:00
Georgi Gerganov dfac6caa40 ggml : print values when overflow 2026-02-22 16:09:53 +02:00
Georgi Gerganov 327e2ca6f2 gguf : minor print fix 2026-02-22 16:01:29 +02:00
Georgi Gerganov 09788740f3 gguf : fix ctx size for no_alloc == true 2026-02-22 15:54:44 +02:00
Georgi Gerganov 4e89ec67fa gguf : better name 2026-02-22 15:47:01 +02:00
Sigbjørn Skjæret 46a9a0656a enforce proper alignment in add_custom_alignment 2026-02-22 10:38:58 +01:00
Georgi Gerganov f2ac3ef57e py : restore tensor_fields 2026-02-22 10:54:14 +02:00
Georgi Gerganov 12c719b3f1 gguf-py : error on duplicate keys when reading 2026-02-22 09:48:16 +02:00
Georgi Gerganov 5d67acd422 ggml : check int overflow in ggml_new_tensor_impl and ggml_new_object 2026-02-22 09:48:16 +02:00
Georgi Gerganov d97dd299a0 py : assert that alignment is non-zero power of 2 2026-02-19 16:44:43 +02:00
Georgi Gerganov 2e23292cfe ggml : fix negative tensor type oob 2026-02-19 16:42:46 +02:00
Georgi Gerganov 7babe5fb13 gguf : prevent array elements exhaustion 2026-02-19 16:33:25 +02:00
Georgi Gerganov 357b8e50f1 gguf : prevent string exhaustion 2026-02-19 16:08:04 +02:00
Georgi Gerganov 69788e0d23 ggml : fix int overflows in ggml_new_object() 2026-02-19 15:59:09 +02:00
Georgi Gerganov 198f79d6c3 gguf : prevent integer overflow for ggml_context mem size 2026-02-19 15:51:00 +02:00
Georgi Gerganov da348c9dfb models : fix qwen3.5 beta/gate shapes (#19730)
* models : fix qwen3.5 beta/gate shapes

* cont : avoid extra reshapes
2026-02-19 15:19:53 +02:00
Saba Fallah e6267a9359 mtmd: build_attn modified, flash_attn on/off via ctx_params (#19729) 2026-02-19 13:50:29 +01:00
3 a l i 2bf318fd2f model : add JAIS-2 architecture support (#19488)
* model: add JAIS-2 architecture support

Add support for the JAIS-2 family of Arabic-English bilingual models
from Inception AI (https://huggingface.co/inceptionai/Jais-2-8B-Chat).

Architecture characteristics:
- LayerNorm (not RMSNorm) with biases
- ReLU² (ReLU squared) activation function
- Separate Q/K/V projections with biases
- Simple MLP without gate projection (up -> act -> down)
- RoPE positional embeddings
- GPT-2 BPE tokenizer

Supported model sizes:
- Jais-2-8B (32 layers, 26 heads, 3328 hidden)
- Jais-2-70B (68 layers, 56 heads, 7168 hidden)

Tested with quantizations: BF16, Q8_0, Q6_K, Q5_K_M, Q5_0, Q4_K_M, Q4_0, Q3_K_M, Q2_K

Note: JAIS-2 requires F32 precision accumulators for numerical stability
and uses standard attention (not flash attention) on CUDA backends.

* fix: run convert_hf_to_gguf_update.py for jais-2 tokenizer hash

* fix: use NEOX RoPE type for JAIS2

* fix: remove Q/K permutation (NEOX RoPE doesn't need it)

* fix: enable flash attention for JAIS2 (fixed by #19115)

* fix: add dedicated JAIS2 pre-tokenizer type and control vector support

- Add LLAMA_VOCAB_PRE_TYPE_JAIS2 with cascading whitespace regex
- Include original regex from tokenizer.json as comment
- Add build_cvec call for control vector support

* no longer necessary to override set_vocab

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 13:30:17 +01:00
Johannes Gäßler c78e682245 CUDA: fix kernel selection logic for tile FA (#19686)
* CUDA: fix kernel selection logic for tile FA

* add comment
2026-02-19 12:42:58 +01:00
Tarek Dakhran c5897995a7 mtmd : chat : Fix extra \n between text and media marker (#19595)
* mtmd : chat : Fix extra \n between text and media marker

Thanks to @tugot17 for detecting and reporting the issue.

For vision models (e.g. LFM2.5-VL-1.6B and Qwen/Qwen3-VL-4B-Instruct) `llama-mtmd-cli` produces identical output to HF implementation.

However `llama-server` doesn't. I traced it down to extra newline
inserted after `<__media__>`.

This happens in `to_json_oaicompat`, that treats media markers as text
and joins all parts with `\n` separator.

PR introduces new type `media_marker` and uses it for media markers.
Extra logic is added to prevent insertion of newlines before and after
media markers.

With this change number of input tokens is identical to HF
implementation and as a result the output is also identical.

I explored other ways to address the issue
* remove completely `\n` between text parts in `to_json_oaicompat`
* merge text messages in server-common.cpp before sending them to `to_json_oaicompat`

Please propose alternative ways of fixing this issue.

* Refactor to use explicite per type ifs

* Update common/chat.cpp

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>

* Update common_chat_templates_apply_legacy

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-02-19 12:18:57 +01:00
Aleksander Grygier 03fd9d3bb4 webui: Fix Attachments not being included in completion request (#19731)
* fix: Add missing argument

* chore: update webui build output
2026-02-19 10:27:38 +01:00
Tarek Dakhran 8004f3a8d1 model : add tokenizer from LFM2.5-Audio-1.5B (#19687)
* model : Add tokenizer from LFM2.5-Audio-1.5B

[LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B) introduced lightweight audio tokenizer.

Tokenizer based on LFM2 architecture and acts as "embedding" model with
different input `n_embd` and output `n_embd_out`.

To be used in https://github.com/ggml-org/llama.cpp/pull/18641.

To convert use

```shell
python3 convert_hf_to_gguf.py /path/to/LFM2.5-Audio-1.5B/audio_detokenizer
```

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Formatting

* Rework check for attention layers

* Add LFM2 SWA model support

* Address PR feedback

* Set vocab to none

* Move helper function definitions to cpp file

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 09:54:48 +01:00
Daniel Bevenius eacb4b67a2 llama : use output_resolve_row() in get_logits_ith/get_embeddings_ith (#19663)
This commit updates get_logits_ith(), and get_embeddings_ith() to use
output_resolve_row() to resolve the batch index to output row index.

The motivation for this is to remove some code duplication between these
functions.
2026-02-19 09:48:08 +01:00
Ryan Mangeno c0d0430340 model : full modern bert support (#18330)
* full modern bert support

* added gelu op in rank pooling for modern bert

* still working on stuff, added mean calculation before classifier head

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* first layer is dense, as per modern bert research paper

* Update src/llama-graph.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fixed set input for mean pooling to check if pooling type is ranking since modern bert does mean & rank

* Update src/llama-graph.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 08:52:21 +01:00
shalinib-ibm 3bb2fcc856 llamafile: powerpc: add FP16 MMA path for Q4/Q8 matmul (#19709)
Avoid xvi8ger4pp signed→unsigned bias correction by dequantizing Q4/Q8
inputs to FP16 and using FP16×FP16→FP32 MMA. This removes
post-processing overhead and improves performance.

Performance Impact:
1.5 ~ 2x improvement in PP_Speed for Q4 and Q8 Models,
measured with llama-bench and llama-batched-bench.
Q8 Model: granite-4.0-h-micro-Q8_0.gguf (from huggingface)
Q4 Model: Meta-Llama3-8b Q4 model (generated with llama-quantize from
f32 model)

llama-bench Q8 Model Results:
 model                          	       size 	     params 	 backend    	 threads 	            test 	Base t/s	Patch t/s
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	             pp8 	         64.48 ± 4.72 	         73.99 ± 0.27
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	            pp16 	         80.11 ± 0.32 	        112.53 ± 0.40
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	            pp32 	         89.10 ± 0.27 	        152.95 ± 0.68
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	            pp64 	         93.65 ± 0.25 	        187.83 ± 0.83
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           pp128 	         99.93 ± 0.02 	        201.32 ± 0.11
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           pp256 	        102.32 ± 0.40 	        208.32 ± 0.41
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           pp512 	        103.42 ± 0.40 	        209.98 ± 0.14
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           tg128 	         20.35 ± 0.01 	         19.57 ± 0.01

llama-bench Q4 Model Results:
 model                          	       size 	     params 	 backend    	 threads 	            test 	              Base    t/s 	               Patch   t/s
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	             pp8 	         34.77 ± 0.10 	         41.23 ± 0.08
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	            pp16 	         40.81 ± 0.04 	         64.55 ± 0.15
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	            pp32 	         44.65 ± 0.05 	         90.84 ± 0.22
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	            pp64 	         47.49 ± 0.03 	        114.39 ± 0.11
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           pp128 	         49.29 ± 0.24 	        120.13 ± 0.19
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           pp256 	         49.77 ± 0.23 	        121.51 ± 0.11
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           pp512 	         49.89 ± 0.23 	        117.52 ± 0.10
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           tg128 	         13.40 ± 0.01 	         13.37 ± 0.00

Llama perplexity Results:

Model	                    Base Final PPL Estimate	Patch Final PPL Estimate
granite-4.0-h-micro-Q8_0    1.3862 +/- 0.04424	        1.3868 +/- 0.04432
Meta-Llama3-8b Q4	    1.3801 +/- 0.04116	        1.3803 +/- 0.04116

Signed-off-by: Shalini.Salomi.Bodapati <Shalini.Salomi.Bodapati@ibm.com>
2026-02-19 14:28:53 +08:00
Georgi Gerganov 27326bfce1 models : dedup qwen35 graphs (#19660)
* models : dedup qwen35 graphs

* cont : add missing sigmoid
2026-02-19 08:17:49 +02:00
ymcki ad9f692f8f models : dedup Kimi Linear delta net implementation (#19668)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments [no ci]

* add kimi linear to delta-net-base

* removed unnecessary ggml_cont from g_exp_t

* removed ggml_cont from g_diff_exp_t. moved ggml_cont for o to kimi-linear.cpp

* removed unnecessary diag mask

* cont : simplify

* cont : avoid graph splits

* scale q after mul instead of beginning

* scale q after mul instead of beginning

* identical ppl

* cont : fix scale and decay mask

* minor : remove TODO

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-19 08:15:17 +02:00
Piotr Wilkin (ilintar) 8a70973557 Add Jinja support for "indent" string filter (#19529)
* Add partial Jinja support for "indent" string filter

* Fully implement indent

* Add tests for all width variants.

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix getline ignoring trailing newlines

* Update common/jinja/value.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix first indent condition

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 00:25:52 +01:00
Reese Levine e7f2f95c9a ggml webgpu: Fix bug in dispatching large matrix-vector multiplication (#19535)
* Fix bug in dispatching large matrix-vector multiplication
2026-02-18 16:06:29 -07:00
matteo b55dcdef5d server: save generated text for the /slots endpoint (for LLAMA_SERVER_SLOTS_DEBUG=1) (#19622)
* save generated text for the /slots endpoint

* update debug_generated_text only when LLAMA_SERVER_SLOTS_DEBUG > 0

* Apply suggestions from code review

---------

Co-authored-by: Matteo <matteo@matteo>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-18 18:53:37 +01:00
Xuan-Son Nguyen eeef3cfced model: support GLM-OCR (#19677)
* model: support GLM-OCR

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-18 17:51:40 +01:00
Maciej Lisowski e99f1083a0 docs: Fix broken links for preparing models in Backends (#19684) 2026-02-18 23:50:23 +08:00
Reese Levine 238856ec8f ggml webgpu: shader library organization (#19530)
* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* flashattention and matrix multiplication moved to new format

* clean up preprocessing

* Formatting

* remove duplicate constants

* Split large shaders into multiple static strings

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
2026-02-18 07:51:02 -07:00
Aleksander Grygier ea003229d3 Pre-MCP UI and architecture cleanup (#19689) 2026-02-18 12:02:02 +01:00
Jeff Bolz d0061be838 vulkan: split mul_mat into multiple dispatches to avoid overflow (#19509)
* vulkan: split mul_mat into multiple dispatches to avoid overflow

The batch dimensions can be greater than the max workgroup count limit,
in which case we need to split into multiple dispatches and pass the base
index through a push constant.

Fall back for the less common p021 and nc variants.

* address feedback
2026-02-18 10:47:10 +01:00
Adrien Gallouët a569bda445 common : make small string helpers as inline functions (#19693)
Also use string_view when it make sense and fix some corner cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-18 08:03:01 +01:00
shaofeiqi e2f19b320f opencl: refactor expm1 and softplus (#19404)
* opencl: refactor expm1

* opencl: refactor softplus

* opencl: use h for half literals

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 14:47:18 -08:00
shaofeiqi 983559d24b opencl: optimize mean and sum_row kernels (#19614)
* opencl: optimize mean and sum_row kernels

* opencl: add comment for max subgroups

* opencl: format

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 13:56:09 -08:00
Daniel Bevenius 2b089c7758 model-conversion : add option to print tensor values (#19692)
This commit updates the tensor-info.py script to support the option to
print the first N values of a tensor when displaying its information.

The motivation for this is that it can be useful to inspect some actual
values in addition to the shapes of the tensors.
2026-02-17 20:43:22 +01:00
Aleksander Grygier afa6bfe4f7 Pre-MCP UI and architecture cleanup (#19685)
* webui: extract non-MCP changes from mcp-mvp review split

* webui: extract additional pre-MCP UI and architecture cleanup

* chore: update webui build output
2026-02-17 13:47:45 +01:00
Talha Can Havadar ae2d3f28a8 ggml: ggml-cpu: force-no-lto-for-cpu-feats (#19609)
When LTO enabled in build environments it forces all builds to have LTO
in place. But feature detection logic is fragile, and causing Illegal
instruction errors with lto. This disables LTO for the feature
detection code to prevent cross-module optimization from inlining
architecture-specific instructions into the score function. Without this,
LTO can cause SIGILL when loading backends on older CPUs (e.g., loading
power10 backend on power9 crashes before feature check runs).
2026-02-17 13:22:46 +02:00
Georgi Gerganov ad8207af77 cuda : enable CUDA graphs for MMID 1 <= BS <= 4 (#19645)
* cuda : enable CUDA graphs for MMID BS <= 4

* cont : add stream capture check

Co-authored-by: Oliver Simons <osimons@nvidia.com>

* cont : add MMVQ_MMID_MAX_BATCH_SIZE

---------

Co-authored-by: Oliver Simons <osimons@nvidia.com>
2026-02-17 12:31:49 +02:00
Daniel Bevenius 667b694278 model-conversion : make printing of config values optional (#19681)
* model-conversion : make printing of config values optional

This commit updates run-org-model.py to make the printing of model
configuration values optional.

The motivation for this change is that not all models have these
configuration values defined and those that do not will error when
running this script. With these changes we only print the values if they
exist or a default value.

We could optionally just remove them but it can be useful to see these
values when running the original model.
2026-02-17 10:46:53 +01:00
Sigbjørn Skjæret e48349a49d ci : bump komac version (#19682) 2026-02-17 09:30:31 +01:00
Adrien Gallouët ae46a61e41 build : link ws2_32 as PUBLIC on Windows (#19666)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-17 08:37:07 +01:00
Adrien Gallouët 65cede7c70 build : cleanup library linking logic (#19665)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-17 08:36:45 +01:00
DAN™ 05fa625eac convert : add JoyAI-LLM-Flash (#19651)
* convert_hf_to_gguf: add JoyAI-LLM-Flash tokenizer hash mapping to deepseek-v3

* llama-vocab: create a new pre-tokenizer name for joyai-llm.

* add missing vocab type section

* Update convert_hf_to_gguf_update.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-16 22:49:57 +01:00
AesSedai d612901116 perplexity: add proper batching (#19661) 2026-02-16 18:44:44 +02:00
Ivan Chikish cceb1b4e33 common : inline functions (#18639) 2026-02-16 17:52:24 +02:00
Judd d23a55997d ggml : make ggml_is_view as API (#19539)
* make `ggml_is_view` as API

* introduce `ggml_aux_is_view` as inline version for internal use.

* change `ggml_aux_is_view` to  `ggml_impl_is_view`
2026-02-16 17:43:34 +02:00
Saurabh Dash 5f28c53d11 model: Add support for Tiny Aya Models (#19611)
* changes for tiny aya

* changes to hash

* changes to vocab

* fix some tokenizer regex edge cases

* update comment

* add some comments for regex

* Apply suggestion from @ngxson

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-16 16:28:46 +01:00
Adrien Gallouët 4408494144 build : rework llama_option_depr to handle LLAMA_CURL (#19658)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-16 16:06:48 +01:00
Mario Limonciello 2ba9adc093 Adjust workaround for ROCWMMA_FATTN/GFX9 to only newer ROCm veresions (#19591)
Avoids issues with ROCm 6.4.4.

Closes: https://github.com/ggml-org/llama.cpp/issues/19580
Fixes: 6845f7f87 ("Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)")

Signed-off-by: Mario Limonciello (AMD) <superm1@kernel.org>
2026-02-16 14:46:08 +01:00
Georgi Gerganov cc45f2ada6 models : deduplicate delta-net graphs for Qwen family (#19597)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments
2026-02-16 14:35:04 +02:00
Georgi Gerganov d5dfc33027 graph : fix KQ mask, lora, cvec reuse checks (#19644)
* graph : fix KQ mask reuse condition

* cont : dedup KQ mask build and can_reuse

* cont : fix build

* graph : fix adapter check for reuse
2026-02-16 09:21:11 +02:00
abhijain1204fujitsu 267ba5a1d9 ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel (#19132)
* Updated repack.cpp

* Updated repack.cpp

* Updated repack.cpp

* Added if condition to support only vector length 256.

* Changed the format removed comments and duplicate variable

* If SVE 256 not present then was using generic function to compute, hence slowing the performance. 

So added code if SVE 256 is not present then use NEON code.

* Code format change suggestion

---------

Co-authored-by: Vithule, Prashant <Prashant.Vithule@fujitsu.com>
2026-02-16 14:38:43 +08:00
Georgi Gerganov ff4affb4c1 sync : ggml 2026-02-15 22:24:29 +02:00
Georgi Gerganov 55d58599c8 ggml : bump version to 0.9.7 (ggml/1425) 2026-02-15 22:24:29 +02:00
Georgi Gerganov 1a8c700bfd ggml : bump version to 0.9.6 (ggml/1423) 2026-02-15 22:24:29 +02:00
David Friehs 27b93cbd15 cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization (#19624)
* cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization

- load all 8 int8 for a grid position in one load
- calculate signs via popcnt instead of fetching from ksigns table
- broadcast signs to drop individual shift/mask

* cuda: iq2xxs: simplify sum scaling

express `(sum * scale + sum / 2) / 4` as `(sum * (scale * 2 + 1)) / 8`
express `((aux32 >> 28) * 2 + 1)` as `(aux32 >> 27 | 1)`

saves 3 registers for mul_mat_vec_q (152 -> 149) according to nsight
AFAICT no overflow can occur here as iq2xxs values are far too small

* uint -> uint32_t

error: identifier "uint" is undefined
2026-02-15 22:38:42 +05:30
Aaron Teo 6e67fd2144 docs: update s390x build docs (#19643) 2026-02-16 00:33:34 +08:00
Adrien Gallouët 9e118b97c4 build : remove LLAMA_HTTPLIB option (#19623)
This option was introduced as a workaround because cpp-httplib could not
build on visionOS. Since it has been fixed and now compiles on all platforms,
we can remove it and simplify many things.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-15 15:38:50 +01:00
Daniel Bevenius 57088276d4 cmake : check if KleidiAI API has been fetched (#19640)
This commit addresses a build issue with the KleidiAI backend when
building multiple cpu backends. Commmit
3a00c98584 ("cmake : fix KleidiAI install
target failure with EXCLUDE_FROM_ALL") introduced a change where
FetchContent_Populate is called instead of FetchContent_MakeAvailable,
where the latter does handle this case (it is idempotent but
FetchContent_Populate is not).

I missed this during my review and I should not have commited without
verifying the CI failure, sorry about that.
2026-02-15 13:59:38 +01:00
Georgi Gerganov 341bc7d23c context : fix output reorder with backend sampling (#19638) 2026-02-15 14:57:40 +02:00
Georgi Gerganov 08e6d914b8 ggml : avoid UB in gemm ukernel (#19642) 2026-02-15 14:56:35 +02:00
Aaron Teo 184c694f45 ggml-cpu: optimize ggml_vec_dot_bf16 for s390x (#19399) 2026-02-15 18:20:35 +08:00
Aman Gupta 684b36101c ggml-cpu: FA add GEMM microkernel (#19422)
* ggml-cpu: FA add GEMM microkernel

* add guard for sizeless vector types

* fix case where DV % GGML_F32_EPR !=0

* move memset out of the loop

* move another memset out of the loop

* use RM=4 for arm

* simd_gemm: convert everything to int

* convert everything to size_t to avoid warnings

* fixup

* add pragma for ignoring aggressive loop optimizations
2026-02-15 11:09:24 +05:30
SamareshSingh 3a00c98584 cmake : fix KleidiAI install target failure with EXCLUDE_FROM_ALL (#19581)
* cmake: fix KleidiAI install target failure with EXCLUDE_FROM_ALL

Fix for the bug #19501 by adding EXCLUDE_FROM_ALL to FetchContent_Declare. This properly excludes KleidiAI from both build and install targets, preventing install failures when GGML_CPU_KLEIDIAI=ON is used.

The KleidiAI source files are still compiled into libggml-cpu.so, preserving all functionality.

* addressed code review comments
2026-02-15 06:22:53 +01:00
Sigbjørn Skjæret 079feab9e3 convert : ensure all models handle new experts count (#19621)
* ensure all models handle new experts count

* revert removal for PhiMoeModel, does not inherit from base
2026-02-14 22:22:32 +01:00
Anav Prasad 01d8eaa28d mtmd : Add Nemotron Nano 12B v2 VL support (#19547)
* nemotron nano v2 vlm support added

* simplified code; addressed reviews

* pre-downsample position embeddings during GGUF conversion for fixed input size
2026-02-14 14:07:00 +01:00
Georgi Gerganov 1725e316c1 models : optimize qwen3next graph (#19375)
* models : optimizing qwen3next graph

* cont

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* cont : remove redundant q, g chunking

* minor

* minor

* avoid passing masks around

* avoid concats during chunking

* naming + shapes

* update names and use prefix to disable CUDA graphs
2026-02-14 12:57:36 +02:00
Adrien Gallouët b7742cf321 ggml : fix GGML_DEBUG with OpenMP (#19599)
last_graph is only available without OpenMP, but
ggml_graph_compute_thread() is called in both cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 11:22:57 +01:00
iMil badba89320 NetBSD build support (#19589) 2026-02-14 09:47:01 +01:00
Aleksander Grygier baa12f3831 webui: Architecture and UI improvements (#19596) 2026-02-14 09:06:41 +01:00
agent-enemy-2 2d8015e8a4 llama : update LoRA API. + fix excessive graph reserves (#19280)
* Refactoring to use new llama_put_adapter_loras

* cont : alternative lora API

---------

Co-authored-by: Jake Chavis <jakechavis6@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-14 10:06:27 +02:00
George eb145c0753 mmap: Fix Windows handle lifetime (#19598)
* ggml: added cleanups in ggml_quantize_free
Add missing cleanup calls for IQ2_S, IQ1_M quantization types and IQ3XS with 512 blocks during quantization cleanup.

* mmap: Fix Windows handle lifetime
Move hMapping from local variable to member variable so it stays alive for the entire lifetime of the mapping.
The file mapping handle must remain valid until UnmapViewOfFile is called.
Fixes cleanup order in destructor.

* Update llama-mmap.cpp

* Update llama-mmap.cpp

Remove trailing whitespace from line 567
2026-02-14 10:05:12 +02:00
Georgi Gerganov 6e473fb384 metal : fix ACC op (#19427) 2026-02-14 09:54:03 +02:00
Adrien Gallouët c7db95f106 scripts : use official split.py for cpp-httplib (#19588)
* scripts : use official split.py for cpp-httplib

Using the official script is safer and ensures the generated code aligns
with the library's standards.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Catch generic errors

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Allow print()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Ensure robust cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 08:41:16 +01:00
Sigbjørn Skjæret 0d00ef65ed convert : store ffn_gate_inp_shexp as F32 (#19606) 2026-02-14 08:17:43 +01:00
Adrien Gallouët 91ea5d67f2 build : fix libtool call in build-xcframework.sh (#19605)
Run libtool via xcrun like strip and dsymutil, to have proper tool resolution.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 06:48:37 +01:00
Jeff Bolz dbb023336b vulkan: support L2_NORM with contiguous rows (#19604) 2026-02-14 06:42:04 +01:00
Jeff Bolz 53aef25a88 vulkan: support GGML_OP_SET (#19584) 2026-02-14 06:36:38 +01:00
Sophon 2dec548094 vulkan: Add vendor id for Qualcomm drivers (#19569)
This commit allows Qualcomm native vulkan driver to be used on Windows
instead of Mesa Dozen.
2026-02-14 06:29:17 +01:00
Max Krasnyansky 0ccbfdef3e hexagon: further optimizations and refactoring for flash attention (#19583)
* ggml-hexagon: fa improvements

ggml-hexagon: optimize flash attention calculations with improved variable handling

ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32

ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx2 by simplifying variable handling for unused elements

ggml-hexagon: optimize flash attention by changing slope vector type to F16

* hexfa: fixed test-backend-ops failurs due to leftover element handling

* hexagon: refactor and optimize fa to use local context struct

* ggml-hexagon: optimize flash-attention using hvx_vec_expf

Use HVX for online softmax.

---------

Co-authored-by: chraac <chraac@gmail.com>
2026-02-13 16:27:30 -08:00
Mengsheng Wu 94a602db66 github : add missing backends to issue templates (#19603) 2026-02-14 00:56:53 +01:00
Jeff Bolz 05a6f0e894 vulkan: restore -inf check in FA shaders (#19582) 2026-02-13 13:35:29 -06:00
Adrien Gallouët b48e80f677 common : update download code (#19573)
* common : remove legacy .json to .etag migration code

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* common : simplify common_download_file_single_online

This commit also force a redownload if the file exists
but has no .etag file.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 15:10:46 +01:00
Xuan-Son Nguyen 752584d5f5 model: support GLM MoE DSA arch (NOTE: indexer is not yet supported) (#19460)
* model: support GLM MoE DSA arch

* working version

* pyright

* keep indexer tensors

* add indexer gguf params

* loaded now

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* minor fix and cleanup

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-13 14:56:53 +01:00
Alberto Cabrera Pérez cc2aa81513 Fix wrong memcpy length for block_interleave == 4 (#19575) 2026-02-13 20:32:14 +08:00
ymcki 0e21991472 fix vulkan ggml_acc only works in 3d but not 4d (#19426)
* fix vulkan ggml_acc only works in 3d but not 4d

* removed clamp in test_acc_block

* use the correct stride and its test case

* cuda : fix "supports op" condition

* change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check

* version without boundary check

* revert back to boundary check version

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-13 13:31:37 +01:00
Sigbjørn Skjæret b2ecc0cdb4 support --verbose-prompt (#19576) 2026-02-13 12:49:10 +01:00
Aman Gupta 5065da554e CUDA: loop over ne2*ne3 in case it overflows (#19538)
* CUDA: loop over ne2*ne3 in case it overflows

* use fastdiv
2026-02-13 17:01:40 +05:30
Aleksander Grygier 5174d7206f webui: UI and routing fixes (#19586)
* chore: update webui build output

* chore: update webui build output

* fix: Scroll issues in DropdownMenuSearchable

* webui: fix redirect to root ignoring base path

* fix: Word wrapping

* fix: remove obsolete modality UI tests causing CI failures

- Remove VisionModality/AudioModality test stories
- Remove mockServerProps usage and imports
- Simplify Default test (remove dropdown interaction checks)
- Simplify FileAttachments test (remove mocks)

* feat: Improve formatting performance time

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-02-13 12:31:00 +01:00
Oliver Simons 43919b7f4f CUDA: Do not mutate cgraph for fused ADDs (#19566)
* Do not mutate cgraph for fused ADDs

1. We should try to minimize in-place changes to the incoming
   ggml_cgraph where possible (those should happen in graph_optimize)
2. Modifying in-place leads to an additional, unnecessary graph capture
   step as we store the properties before modifying the graph in-place
   in the cuda-backend

* Assert ggml_tensor is trivially copyable

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-02-13 15:07:55 +05:30
Pavan Shinde 423cf0b26f docs : fix broken link and typo (#19560) 2026-02-13 09:38:09 +01:00
ymcki 33a56f90a6 model : Kimi Linear fix conv state update (#19531)
* fix conv state update for llama-server parallel serving

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-02-13 09:10:18 +01:00
Adrien Gallouët 25224c8021 llama : remove deprecated codecvt (#19565)
Using the same conversion function ensures a consistent matching between
the regex pattern and the text.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:53 +01:00
Adrien Gallouët 2f5d8f8edc vendor : update BoringSSL to 0.20260211.0 (#19562)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:26 +01:00
Georgi Gerganov bb96bfd361 memory : fix kv cache size for hybrid models (#19559) 2026-02-13 07:36:24 +02:00
Georgi Gerganov 0644baefde metal : improve concurrency (#19555) 2026-02-13 07:35:57 +02:00
Georgi Gerganov 490eb96b88 metal : support GGML_OP_SET (#19548) 2026-02-13 07:34:52 +02:00
Shupei Fan 3bb78133ab hexagon: fix typo in vtcm_needs_release (#19545) 2026-02-12 15:07:49 -08:00
lhez 79cc0f2daf opencl: add basic support for q4_1 (#19534)
* opencl: add q4_1 mv

* opencl: clean up

* opencl: add flattened q4_1 mv

* opencl: clean up

* opencl: add basic q4_1 mm

* opencl: fix whitespace

* opencl: add general q4_0 mm
2026-02-12 14:52:37 -08:00
Georgi Gerganov 338085c69e args : add -kvu to llama-parallel (#19577) 2026-02-12 21:52:41 +02:00
Aleksander Grygier 4c61875bf8 webui: Add switcher to Chat Message UI to show raw LLM output (#19571) 2026-02-12 19:55:51 +01:00
Adrien Gallouët 4b385bfcf8 vendor : update cpp-httplib (#19537)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-12 16:11:22 +01:00
Christian Schmitz f488429380 llama : update outdated comment in llama.h (#19428)
* Updated documentation

Model is no longer a parameter

* llama : fix trailing whitespace in comment

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
2026-02-12 15:52:57 +01:00
Aleksander Grygier 4d688f9ebb (webui) FEATURE: Enable adding or injecting System Message into chat (#19556)
* feat: Enable adding System Prompt per-chat

* fix: Save draft message in Chat Form when adding System Prompt from new chat view

* fix: Proper system message deletion logic

* chore: Formatting

* chore: update webui build output
2026-02-12 13:56:08 +01:00
Daniel Bevenius ff599039a9 scripts : add support for forks in pr2wt.sh (#19540)
This commit adds support for using the pr2wt.sh (pull request to
workspace) script with forks of upstream llama.cpp.
2026-02-12 13:14:28 +01:00
Aleksander Grygier f486ce9f30 (webui) REFACTOR: UI primitives and polish (#19551)
* webui: UI primitives and polish (non-MCP)

* chore: update webui build output
2026-02-12 12:21:00 +01:00
Aleksander Grygier 38adc7d469 WebUI Architecture Cleanup (#19541)
* webui: architecture foundation (non-MCP core refactors)

* chore: update webui build output
2026-02-12 11:22:27 +01:00
Georgi Gerganov 3b3a948134 metal : update sum_rows kernel to support float4 (#19524) 2026-02-12 11:35:28 +02:00
Mario Limonciello 6845f7f87f Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)
There is an upstream problem [1] with AMD's LLVM 22 fork and
rocWMMA 2.2.0 causing compilation issues on devices without
native fp16 support (CDNA devices).

The specialized types aren't resolved properly:
```
/opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>'
 2549 |             using ARegsT = typename Impl::ARegsT;
```

Add a workaround to explicitly declare the types and cast when
compiling with HIP and ROCWMMA_FATTN [2].  When this is actually
fixed upstream some guards can be used to detect and wrap the
version that has the fix to only apply when necessary.

Link: https://github.com/ROCm/rocm-libraries/issues/4398 [1]
Link: https://github.com/ggml-org/llama.cpp/issues/19269 [2]

Signed-off-by: Mario Limonciello <mario.limonciello@amd.com>
2026-02-12 09:38:35 +01:00
RichardScottOZ fa16e517a3 server : fix typo in README.md for features list (#19510)
extra l for full
2026-02-12 08:56:25 +01:00
TriDefender 313493de53 docs : update path in snapdragon README.md (#19533)
paths changed so original example didn't work
2026-02-12 08:13:51 +01:00
Max Krasnyansky b1ff83bbb0 hexagon: further optimization and tuning of matmul and dot kernels (#19407)
* ggml-hexagon: implement 2x2 matmul kernel

* hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4

* hexagon: fix editor config failures

* hexagon: refactor matmul ops to use context struct and remove wrappers

Also implement vec_dot_f16 2x2

* hexagon: refactor dyn quantizers to use mmctx

* hexagon: remove mm fastdiv from op_ctx

* hexagon: refactor matmul entry point to reduce code duplication

---------

Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
2026-02-11 23:04:27 -08:00
Adrien Gallouët 4ae1b7517a common : replace deprecated codecvt using parse_utf8_codepoint (#19517)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-12 07:27:52 +01:00
lhez 4d3daf80f8 opencl: add general Q6_K mm and Q4_K mv (#19347)
* opencl: add general q6_k mm

* opencl: refine condition for q6_K mm

* opencl: add general q4_K mv

* opencl: fix whitespace
2026-02-11 10:33:13 -08:00
Georgi Gerganov 914dde72ba ggml : unary ops support non-cont src0 + metal F16 unary ops (#19511)
* ggml : unary ops support non-cont src0

* metal : support F16 unary ops + fix ELU
2026-02-11 18:58:43 +02:00
Daniel Bevenius 3136a849db common : remove unused token util functions (#19506)
This commit removes two unused functions `common_lcp` and `common_lcs`.
The last usage of these functions was removed in
Commit 33eff40240 ("server : vision support
via libmtmd") and are no longer used anywhere in the codebase.
2026-02-11 17:41:35 +01:00
AesSedai e463bbdf65 model: Add Kimi-K2.5 support (#19170)
* Move dequant_model to after the text_config merge
Add new kimi-k2.5 keys to mtmd convert
Update V_MMPROJ tensor mapping for new mm_projector.proj keys
Update V_M_IMP_NORM for new mm_projector.pre_norm key

* Fix a couple of oversights

* Add image support for Kimi-K2.5

* Revert changes to KimiVLForConditionalGeneration

* Fix an assert crash

* Fix permute swapping w / h on accident

* Kimi-K2.5: Use merged QKV for vision

* Kimi-K2.5: pre-convert vision QK to use build_rope_2d

* Kimi-K2.5: support non-interleaved rope for vision

* Kimi-K2.5: fix min / max pixel

* Kimi-K2.5: remove v/o permutes, unnecessary

* Kimi-K2.5: update permute name to match

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Kimi-K2.5: replace build_rope_2d ggml_cont with ggml_view_3d pointers

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-11 16:47:30 +01:00
Daniel Bevenius 53de59f67d build : fix case in dSYMs path for build-macos [no ci] (#19515)
This commit updates an incorrect dSYMs where the the 's' was uppercase
by mistake.

The motivation for fixing this is that this can cause issues on case
sensitive operating systems.

Refs: https://github.com/ggml-org/whisper.cpp/pull/3630
2026-02-11 14:02:29 +01:00
Georgi Gerganov 9ab072ebbe metal : extend l2_norm support for non-cont src0 (#19502) 2026-02-11 14:53:19 +02:00
Johannes Gäßler ada90bf2ba docs: ban AI for issues and discussions [no CI] (#19512) 2026-02-11 12:49:40 +01:00
Adrien Gallouët 0c1f39a9ae common : improve download error reporting (#19491)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-11 09:27:55 +01:00
Max Krasnyansky 73cd5e1b97 hexagon: Add ARGSORT, DIV, SQR, SQRT, SUM_ROWS, GEGLU (#19406)
* hexagon: add ARGSORT op

Co-authored-by: Yarden Tal <yardent@qti.qualcomm.com>

* hexagon: argsort reject tensors with huge rows for now

* Adding support for DIV,SQR,SQRT,SUM_ROWS ops in hexagon backend

* hexagon : Add GEGLU op

* hexagon: fix editor config check

* hexagon: rewrite and optimize binary ops ADD/SUB/MUL/DIV/ADD_ID to use DMA

---------

Co-authored-by: Yarden Tal <yardent@qti.qualcomm.com>
Co-authored-by: Manohara Hosakoppa Krishnamurthy <mhosakop@qti.qualcomm.com>
2026-02-10 23:21:12 -08:00
thecaptain789 8ee538ce73 llama : correct typos 'occured' and 'occurences' (#19414)
Co-authored-by: thecaptain789 <thecaptain789@users.noreply.github.com>
2026-02-11 07:05:31 +01:00
Georgi Gerganov 6d95707827 model : fix wavtokenizer embedding notions (#19479) 2026-02-11 07:52:20 +02:00
Georgi Gerganov 89181c0b6d ggml : extend bin bcast for permuted src1 (#19484)
* tests : extend bin bcast for permuted src1

* cont : extend bin support

* cont : s0 is always 1

* tests : simplify
2026-02-11 07:52:00 +02:00
Georgi Gerganov ceaa89b786 metal : consolidate unary ops (#19490) 2026-02-11 07:51:12 +02:00
Daniel Bevenius 2cce9fddb7 llama : refactor sampling_info to use buffer_view template (#19368)
* llama : refactor sampling_info to use buffer_view template

This commit updates the sampling_info struct in llama-context to use a
buffer_view template for the logits, probs, sampled tokens, and
candidates buffers.

The motivation for this is to simplify the code, improve type safety
and readability.
2026-02-11 05:38:13 +01:00
Oliver Simons 612db61886 CUDA : Update CCCL-tag for 3.2 to final release from RC (#19486)
CCCL 3.2 has been released since it was added to llama.cpp as part of
the backend-sampling PR, and it makes sense to update from RC to final
released version.

https://github.com/NVIDIA/cccl/releases/tag/v3.2.0
2026-02-10 22:31:19 +01:00
Nikhil Jain 57487a64c8 [WebGPU] Plug memory leaks and free resources on shutdown (#19315)
* Fix memory leaks in shader lib, backend, backend_context, buffer_context, and webgpu_buf_pool

* Free pools

* Cleanup

* More cleanup

* Run clang-format

* Fix arg-parser and tokenizer test errors that free an unallocated buffer

* Fix device lost callback to not print on device teardown

* Fix include and run clang-format

* remove unused unused

* Update binary ops

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
2026-02-10 08:04:00 -08:00
JJJYmmm fc0fe40049 models : support qwen3.5 series (#19468)
* support qwen3.5 series

* remove deepstack for now, and some code clean

* code clean

* add FULL_ATTENTION_INTERVAL metadata

* code clean

* reorder v heads for linear attention to avoid expensive interleaved repeat
2026-02-10 18:00:26 +02:00
Xuan-Son Nguyen 9a96352729 test: fix IMROPE perf test case (#19465) 2026-02-10 14:37:50 +01:00
Alberto Cabrera Pérez c03a5a46f0 ggml-cpu: arm64: q6_K repack gemm and gemv (and generic) implementations (dotprod) (#19360)
* First working version of GEMM and GEMV

* interleave loads and compute

* Clang-format

* Added missing fallback. Removed tested TODO.

* Swap M and N to be consistent with the repack template convention
2026-02-10 10:47:45 +00:00
k4ss4n 6948adc90d ggml : use noexcept overload for is_regular_file in backend registration (#19452)
using noexcept std::filesystem::directory_entry::is_regular_file
overload prevents abnormal termination upon throwing an error
(as caused by symlinks to non-existent folders on linux)

Resolves: #18560
2026-02-10 10:57:48 +01:00
Piotr Wilkin (ilintar) 854b09f0d7 convert : move experts permutation from Qwen2MoeModel to Qwen3VLMoeTextModel (#19445)
* Add special case for Qwen3VLMoe

* Fix down path, remove arrows and checkmarks

* ws

* Moved to Qwen3VL

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-10 09:01:37 +01:00
Daniel Bevenius 66d403c480 tts : fix typos in README.md [no ci] (#19463) 2026-02-10 07:30:41 +01:00
Raul Torres f0bfe54f55 CANN: Remove unnecessary wrapper for gml_backend_buft_is_cann (#18968) 2026-02-10 14:19:30 +08:00
342 changed files with 29273 additions and 14893 deletions
@@ -41,7 +41,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true
+1 -1
View File
@@ -42,7 +42,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Install komac
run: |
cargo binstall komac@2.11.2 -y
cargo binstall komac@2.15.0 -y
- name: Find latest release
id: find_latest_release
+1 -1
View File
@@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
- Explicitly informing them that AI-generated pull requests are not accepted by the project
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Providing useful links and pointers found throughout the codebase
Examples of valid questions:
+10 -12
View File
@@ -112,15 +112,9 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
# 3rd party libs
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
# deprecated
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
if (LLAMA_CURL)
message(WARNING "LLAMA_CURL option is deprecated and will be ignored")
endif()
# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -148,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS)
endif()
# transition helpers
function (llama_option_depr TYPE OLD NEW)
function (llama_option_depr TYPE OLD)
if (${OLD})
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
set(${NEW} ON PARENT_SCOPE)
set(NEW "${ARGV2}")
if(NEW)
message(${TYPE} "${OLD} is deprecated, use ${NEW} instead")
set(${NEW} ON PARENT_SCOPE)
else()
message(${TYPE} "${OLD} is deprecated and will be ignored")
endif()
endif()
endfunction()
@@ -164,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
llama_option_depr(WARNING LLAMA_CURL)
include("cmake/license.cmake")
license_add_file("llama.cpp" "LICENSE")
@@ -197,9 +197,7 @@ add_subdirectory(src)
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
if (LLAMA_HTTPLIB)
add_subdirectory(vendor/cpp-httplib)
endif()
add_subdirectory(vendor/cpp-httplib)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+1 -1
View File
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
1. Explicitly disclose the manner in which AI was employed.
2. Perform a comprehensive manual review prior to submitting the pull request.
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
+1 -1
View File
@@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
> [!IMPORTANT]
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
## Requirements
+14 -18
View File
@@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=(
-DGGML_OPENMP=${GGML_OPENMP}
)
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
check_required_tool() {
local tool=$1
local install_message=$2
@@ -60,9 +55,12 @@ check_required_tool() {
}
echo "Checking for required tools..."
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
set -e
@@ -260,7 +258,7 @@ combine_static_libraries() {
# Since we have multiple architectures libtool will find object files that do not
# match the target architecture. We suppress these warnings.
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
# Determine SDK, architectures, and install_name based on platform and simulator flag.
local sdk=""
@@ -333,7 +331,7 @@ combine_static_libraries() {
# Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then
if command -v xcrun vtool &>/dev/null; then
if xcrun -f vtool &>/dev/null; then
case "$platform" in
"ios")
echo "Marking binary as a framework binary for iOS..."
@@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos --config Release -- -quiet
@@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos-sim --config Release -- -quiet
@@ -528,13 +524,13 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
# Create XCFramework with correct debug symbols paths
echo "Creating XCFramework..."
xcodebuild -create-xcframework \
xcrun xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-sim/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
-framework $(pwd)/build-ios-device/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-macos/framework/llama.framework \
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
+11 -27
View File
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
llama_add_compile_flags()
# Build info header
#
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
@@ -110,33 +109,16 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
set(LLAMA_COMMON_EXTRA_LIBS build_info)
if (LLAMA_HTTPLIB)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
endif()
target_link_libraries(${TARGET} PRIVATE
build_info
cpp-httplib
)
if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
# Set the correct library file extension based on platform
if (WIN32)
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
# Add Windows-specific libraries
set(LLGUIDANCE_PLATFORM_LIBS
ws2_32 # Windows Sockets API
userenv # For GetUserProfileDirectoryW
ntdll # For NT functions
bcrypt # For BCryptGenRandom
)
else()
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
set(LLGUIDANCE_PLATFORM_LIBS "")
endif()
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
@@ -158,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
add_dependencies(llguidance llguidance_ext)
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
# Add platform libraries to the main target
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
endif ()
target_link_libraries(${TARGET} PRIVATE llguidance)
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
endif()
endif()
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
+1 -1
View File
@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
+15 -4
View File
@@ -65,14 +65,25 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
} else if (!content_parts.empty()) {
if (concat_typed_text) {
std::string text;
bool last_was_media_marker = false;
// join parts with newline, do not add newline before or after media markers
for (const auto & part : content_parts) {
if (part.type != "text") {
bool add_new_line = true;
if (part.type == "text") {
add_new_line = !last_was_media_marker && !text.empty();
last_was_media_marker = false;
} else if (part.type == "media_marker") {
add_new_line = false;
last_was_media_marker = true;
} else {
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
continue;
}
if (!text.empty()) {
if (add_new_line) {
text += '\n';
}
text += part.text;
}
jmsg["content"] = text;
@@ -319,7 +330,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
throw std::invalid_argument("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
if (type != "text" && type != "media_marker") {
throw std::invalid_argument("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
@@ -3307,7 +3318,7 @@ static common_chat_params common_chat_templates_apply_legacy(
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
if (part.type != "text" && part.type != "media_marker") {
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
continue;
}
+32 -135
View File
@@ -1,7 +1,3 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "gguf.h"
@@ -9,12 +5,12 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "unicode.h"
#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <chrono>
#include <cstdarg>
#include <cstring>
@@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
bool has_suffix = string_ends_with(str, suffix);
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$&");
@@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
return false;
}
std::u32string filename_utf32;
try {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
size_t offset = 0;
while (offset < filename.size()) {
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
filename_utf32 = converter.from_bytes(filename);
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
// or invalid encodings were encountered. Reject such attempts
std::string filename_reencoded = converter.to_bytes(filename_utf32);
if (filename_reencoded != filename) {
if (result.status != utf8_parse_result::SUCCESS) {
return false;
}
} catch (const std::exception &) {
return false;
}
uint32_t c = result.codepoint;
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
for (char32_t c : filename_utf32) {
if ((result.bytes_consumed == 2 && c < 0x80) ||
(result.bytes_consumed == 3 && c < 0x800) ||
(result.bytes_consumed == 4 && c < 0x10000)) {
return false;
}
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
if (c <= 0x1F // Control characters (C0)
|| c == 0x7F // Control characters (DEL)
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
@@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|| c == 0x2215 // Division Slash (forward slash equivalent)
|| c == 0x2216 // Set Minus (backslash equivalent)
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c > 0x10FFFF // Max Unicode limit
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == ':' || c == '*' // Illegal characters
@@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
// Subdirectories not allowed, reject path separators
return false;
}
offset += result.bytes_consumed;
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -898,7 +851,8 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
defined(__OpenBSD__) || defined(__NetBSD__)
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else if (std::getenv("HOME")) {
@@ -1242,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
return res;
}
int err = llama_apply_adapter_cvec(
int err = llama_set_adapter_cvec(
lctx,
cvec.data.data(),
cvec.data.size(),
@@ -1344,12 +1298,15 @@ std::string get_model_endpoint() {
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_set_adapter_lora(ctx, la.ptr, la.scale);
}
std::vector<llama_adapter_lora *> loras;
std::vector<float> scales;
for (auto & la: lora) {
loras.push_back(la.ptr);
scales.push_back(la.scale);
}
llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
}
struct llama_model_params common_model_params_to_llama(common_params & params) {
@@ -1469,66 +1426,6 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
//
// Vocab utils
//
+41 -26
View File
@@ -670,30 +670,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
}
template<>
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
size_t delim_pos = str.find(delim);
while (delim_pos != std::string::npos) {
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
begin_pos = delim_pos + 1;
delim_pos = str.find(delim, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
parts.emplace_back(str.substr(begin_pos));
return parts;
}
static bool string_starts_with(const std::string & str,
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
// remove when moving to c++20
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
return str.size() >= prefix.size() &&
str.compare(0, prefix.size(), prefix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
// remove when moving to c++20
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
return str.size() >= suffix.size() &&
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
if (string_ends_with(str, suffix)) {
str.resize(str.size() - suffix.size());
return true;
}
return false;
}
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
if (!str.empty() && !stop.empty()) {
const size_t max_len = std::min(str.size(), stop.size());
const char last_char = str.back();
for (size_t len = max_len; len > 0; --len) {
if (stop[len - 1] == last_char) {
if (string_ends_with(str, stop.substr(0, len))) {
return str.size() - len;
}
}
}
}
return std::string::npos;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -779,16 +804,6 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longet common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//
@@ -880,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
static std::string llm_ffn_exps_block_regex(int idx) {
inline std::string llm_ffn_exps_block_regex(int idx) {
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
}
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
}
+78 -136
View File
@@ -19,9 +19,7 @@
#include <thread>
#include <vector>
#if defined(LLAMA_USE_HTTPLIB)
#include "http.h"
#endif
#ifndef __EMSCRIPTEN__
#ifdef __linux__
@@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
}
static std::string read_etag(const std::string & path) {
std::string none;
const std::string etag_path = path + ".etag";
if (std::filesystem::exists(etag_path)) {
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return none;
}
std::string etag;
std::getline(etag_in, etag);
return etag;
if (!std::filesystem::exists(etag_path)) {
return {};
}
// no etag file, but maybe there is an old .json
// remove this code later
const std::string metadata_path = path + ".json";
if (std::filesystem::exists(metadata_path)) {
std::ifstream metadata_in(metadata_path);
try {
nlohmann::json metadata_json;
metadata_in >> metadata_json;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata_json.dump().c_str());
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
std::string etag = metadata_json.at("etag");
write_etag(path, etag);
if (!std::filesystem::remove(metadata_path)) {
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
}
return etag;
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return {};
}
return none;
std::string etag;
std::getline(etag_in, etag);
return etag;
}
static bool is_http_status_ok(int status) {
@@ -168,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
return {hf_repo, tag};
}
#if defined(LLAMA_USE_HTTPLIB)
class ProgressBar {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
@@ -305,7 +275,10 @@ static bool common_pull_file(httplib::Client & cli,
);
if (!res) {
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
LOG_ERR("%s: download failed: %s (status: %d)\n",
__func__,
httplib::to_string(res.error()).c_str(),
res ? res->status : -1);
return false;
}
@@ -344,62 +317,64 @@ static int common_download_file_single_online(const std::string & url,
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
for (int i = 0; i < max_attempts; ++i) {
auto head = cli.Head(parts.path);
bool head_ok = head && head->status >= 200 && head->status < 300;
if (!head_ok) {
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head->status; // cannot use cached file, return raw status code
// TODO: maybe retry only on certain codes
}
std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head_ok && head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head_ok && head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
bool should_download_from_scratch = false;
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
last_etag.c_str(), etag.c_str());
should_download_from_scratch = true;
}
auto head = cli.Head(parts.path);
if (!head || head->status < 200 || head->status >= 300) {
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
if (!should_download_from_scratch) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head ? head->status : -1;
}
std::string etag;
if (head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
if (file_exists) {
if (etag.empty()) {
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (!last_etag.empty() && last_etag == etag) {
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
}
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;
for (int i = 0; i < max_attempts; ++i) {
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
delay *= retry_delay_seconds;
}
const std::string path_temporary = path + ".downloadInProgress";
size_t existing_size = 0;
if (std::filesystem::exists(path_temporary)) {
if (supports_ranges && !should_download_from_scratch) {
if (supports_ranges) {
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
@@ -407,32 +382,23 @@ static int common_download_file_single_online(const std::string & url,
}
}
// start the download
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
} else {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str());
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
continue;
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status;
}
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status; // TODO: use actual GET status?
}
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}
@@ -798,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
}
}
#else
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
int common_download_file_single(const std::string &,
const std::string &,
const std::string &,
bool,
const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // defined(LLAMA_USE_HTTPLIB)
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
+41 -2
View File
@@ -4,6 +4,7 @@
// for converting from JSON to jinja values
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
#include <cctype>
#include <vector>
@@ -715,8 +716,46 @@ const func_builtins & value_string_t::get_builtins() const {
return args.get_pos(0);
}},
{"tojson", tojson},
{"indent", [](const func_args &) -> value {
throw not_implemented_exception("String indent builtin not implemented");
{"indent", [](const func_args &args) -> value {
args.ensure_count(1, 4);
value val_input = args.get_pos(0);
value val_width = args.get_kwarg_or_pos("width", 1);
const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
if (!is_val<value_string>(val_input)) {
throw raised_exception("indent() first argument must be a string");
}
std::string indent;
if (is_val<value_int>(val_width)) {
indent.assign(val_width->as_int(), ' ');
} else if (is_val<value_string>(val_width)) {
indent = val_width->as_string().str();
} else {
indent = " ";
}
std::string indented;
std::string input = val_input->as_string().str();
std::istringstream iss = std::istringstream(input);
std::string line;
while (std::getline(iss, line)) {
if (!indented.empty()) {
indented.push_back('\n');
}
if ((indented.empty() ? first : (!line.empty() || blank))) {
indented += indent;
}
indented += line;
}
if (!input.empty() && input.back() == '\n') {
indented.push_back('\n');
if (blank) {
indented += indent;
}
}
auto res = mk_val<value_string>(indented);
res->val_str.mark_input_based_on(val_input->as_string());
return res;
}},
{"join", [](const func_args &) -> value {
throw not_implemented_exception("String join builtin not implemented");
+1 -1
View File
@@ -461,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map,
slot_max = v;
}
}
// What is sum of the other occurences?
// What is sum of the other occurrences?
uint32_t sum_occur = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (v == slot_max) {
+2 -2
View File
@@ -44,7 +44,7 @@ llama_tokens common_ngram_simple_draft(
// statistics of a m-gram after a known n-gram
struct common_ngram_map_value {
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
};
@@ -53,7 +53,7 @@ struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
uint16_t key_num; // number of occurences of this key n-gram in token-history
uint16_t key_num; // number of occurrences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
};
+514 -130
View File
File diff suppressed because it is too large Load Diff
+5 -1
View File
@@ -99,6 +99,7 @@ models = [
{"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
{"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", },
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
@@ -113,6 +114,7 @@ models = [
{"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
{"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
{"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", },
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
@@ -148,6 +150,8 @@ models = [
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@@ -157,6 +161,7 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
@@ -170,7 +175,6 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]
+1 -1
View File
@@ -246,7 +246,7 @@ cmake --build build --config release
1. **Retrieve and prepare model**
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
**Notes**:
+2 -2
View File
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
+1 -1
View File
@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
Preset CMake variables:
+3 -3
View File
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ✅ | ❓ |
| BF16 | 🚫 | ✅ | ❓ |
| BF16 | | ✅ | ❓ |
| Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | ❓ | ❓ |
| MXFP4 | | ❓ | ❓ |
| Q5_0 | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ❓ | ❓ |
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
config = config.text_config
multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
def print_if_exists(label, obj, attr, default="N/A"):
val = getattr(obj, attr) if hasattr(obj, attr) else default
print(f"{label}", val)
print_if_exists("Vocab size: ", config, "vocab_size")
print_if_exists("Hidden size: ", config, "hidden_size")
print_if_exists("Number of layers: ", config, "num_hidden_layers")
print_if_exists("BOS token id: ", config, "bos_token_id")
print_if_exists("EOS token id: ", config, "eos_token_id")
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name:
@@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False):
print(tensor_name)
def print_tensor_info(model_path: Path, tensor_name: str):
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
@@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str):
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
if num_values is not None:
tensor = f.get_tensor(tensor_name)
print(f"Dtype: {tensor.dtype}")
flat = tensor.flatten()
n = min(num_values, flat.numel())
print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
@@ -127,6 +133,15 @@ def main():
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
parser.add_argument(
"-n", "--num-values",
nargs="?",
const=10,
default=None,
type=int,
metavar="N",
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
)
args = parser.parse_args()
@@ -152,7 +167,7 @@ def main():
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name)
print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":
+1 -1
View File
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_PATCH 7)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
+1 -4
View File
@@ -730,10 +730,6 @@ extern "C" {
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
GGML_DEPRECATED(
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
"use ggml_row_size() instead");
GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);
@@ -752,6 +748,7 @@ extern "C" {
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_view (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
+4 -9
View File
@@ -17,11 +17,6 @@
//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
#define AT_PRINTF(...)
static bool ggml_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
// ops that return true for this function must not use restrict pointers for their backend implementations
bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
GGML_ASSERT(buffer_id >= 0);
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
hn->allocated = true;
assert(hn->addr.offset == 0);
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
if (p_hn->n_children == 1 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
if (ggml_impl_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
// GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
// control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
// itself is never used and should not be considered a dependency
if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
struct ggml_tensor * view_src = node->view_src;
ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
}
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
if (ggml_impl_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
view_src_hn->n_views -= 1;
+3 -2
View File
@@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
int best_score = 0;
fs::path best_path;
std::error_code ec;
for (const auto & search_path : search_paths) {
if (std::error_code ec; !fs::exists(search_path, ec)) {
if (!fs::exists(search_path, ec)) {
if (ec) {
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
} else {
@@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
if (entry.is_regular_file()) {
if (entry.is_regular_file(ec)) {
auto filename = entry.path().filename();
auto ext = entry.path().extension();
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
+37 -52
View File
@@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context {
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
};
// cann buffer type
/**
* @brief Check if a buffer is a CANN buffer.
*
* This function checks if a given buffer is a CANN buffer by comparing its
* `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
*
* @param buffer The buffer to check.
* @return true if the buffer is a CANN buffer, false otherwise.
* @brief Structure representing context information for a specific backend
* buffer type.
*/
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
struct ggml_backend_cann_buffer_type_context {
int32_t device; /**< Device identifier associated with the buffer context. */
std::string name; /**< Name associated with the buffer context. */
};
static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
return ggml_backend_buft_is_cann(buffer->buft);
/**
* @brief Retrieves the name associated with a CANN buffer type.
*
* This function returns the descriptive name associated with the specified
* CANN buffer type context.
*
* @param buft Pointer to the buffer type context.
* @return Const pointer to the C-style string containing the name.
*/
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
return buft_ctx->name.c_str();
}
/**
* @brief Checks if the backend buffer type is associated with the CANN backend.
*
* This function checks whether the provided backend buffer type is associated
* with the CANN backend based on the comparison of its name retrieval function
* pointer.
*
* @param buft Pointer to the backend buffer type to check.
* @return bool Returns true if the buffer type is associated with the CANN
* backend, otherwise false.
*/
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}
/**
@@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor * src,
ggml_tensor * dst) {
if (ggml_backend_buffer_is_cann(src->buffer)) {
if (ggml_backend_buft_is_cann(src->buffer->buft)) {
ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
@@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .reset = */ NULL,
};
// cann buffer type
/**
* @brief Structure representing context information for a specific backend
* buffer type.
*/
struct ggml_backend_cann_buffer_type_context {
int32_t device; /**< Device identifier associated with the buffer context. */
std::string name; /**< Name associated with the buffer context. */
};
/**
* @brief Retrieves the name associated with a CANN buffer type.
*
* This function returns the descriptive name associated with the specified
* CANN buffer type context.
*
* @param buft Pointer to the buffer type context.
* @return Const pointer to the C-style string containing the name.
*/
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
return buft_ctx->name.c_str();
}
/**
* @brief Allocates a new CANN buffer of the specified type and size.
*
@@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) {
if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
return false;
}
@@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
GGML_UNUSED(dev);
}
/**
* @brief Checks if the backend buffer type is associated with the CANN backend.
*
* This function checks whether the provided backend buffer type is associated
* with the CANN backend based on the comparison of its name retrieval function
* pointer.
*
* @param buft Pointer to the backend buffer type to check.
* @return bool Returns true if the buffer type is associated with the CANN
* backend, otherwise false.
*/
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}
/**
* @brief Records an event on the CANN backend stream.
*
+9 -7
View File
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# Disable LTO for the feature detection code to prevent cross-module optimization
# from inlining architecture-specific instructions into the score function.
# Without this, LTO can cause SIGILL when loading backends on older CPUs
# (e.g., loading power10 backend on power9 crashes before feature check runs).
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
endfunction()
@@ -569,27 +574,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
cmake_policy(SET CMP0135 NEW)
endif()
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
FetchContent_Declare(KleidiAI_Download
URL ${KLEIDIAI_DOWNLOAD_URL}
DOWNLOAD_EXTRACT_TIMESTAMP NEW
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
FetchContent_MakeAvailable(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
message(FATAL_ERROR "KleidiAI source downloaded failed.")
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
# Remove kleidiai target after fetching it
if (TARGET kleidiai)
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
endif()
list(APPEND GGML_CPU_SOURCES
ggml-cpu/kleidiai/kleidiai.cpp
ggml-cpu/kleidiai/kernels.cpp
+15 -1
View File
@@ -43,6 +43,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -55,7 +56,8 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@@ -76,6 +78,7 @@
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@@ -84,6 +87,7 @@
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@@ -107,6 +111,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -119,6 +124,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@@ -143,6 +149,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -155,6 +162,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@@ -186,6 +194,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -197,6 +206,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@@ -227,6 +237,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -239,6 +250,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@@ -271,6 +283,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -283,6 +296,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+705 -5
View File
@@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n,
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[2];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int i = 0; i < col_groups; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
int32x4_t acc[col_groups];
for (int i = 0; i < col_groups; i++) {
acc[i] = vdupq_n_s32(0);
}
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
// Reused for bias and dequantization later
int16_t q6_scales[16 * 8];
for (int i = 0; i < 16; i++) {
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
vst1q_s16(q6_scales + i * 8, scales);
}
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
int32x4_t bias_lo = vdupq_n_s32(0);
int32x4_t bias_hi = vdupq_n_s32(0);
// Load bsums in chunks of 4 to process with vectorized operations
for (int i = 0; i < 16; i += 4) {
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
}
bias_lo = vshlq_n_s32(bias_lo, 5);
bias_hi = vshlq_n_s32(bias_hi, 5);
// Process two 128-value halves per superblock
for (int half = 0; half < 2; half++) {
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
// A subblock (sb) is a set of weights that share the scale
// Since q6_K scales are per 16 elements
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
for (int sb = 0; sb < QK_K / 64; sb++) {
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
const int8_t * q8_base_h = q8_base_l + 64;
// Load and duplicate q8 values (each register covers four interleaved columns of q6)
int8x16_t q8_l[4];
int8x16_t q8_h[4];
for (int i = 0; i < 4; i++) {
q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
}
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
// Adjust qh for subblocks 2 and 3 (shift right by 2)
if (sb > 1) {
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
}
const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
// Process column groups (0-3, 4-7)
for (int g = 0; g < col_groups; g++) {
int32x4_t sb_acc_l = vdupq_n_s32(0);
int32x4_t sb_acc_h = vdupq_n_s32(0);
for (int chunk = 0; chunk < 4; chunk++) {
const int idx = chunk * 2 + g;
const uint8x16_t q6_qs_l = q6_ql[idx];
const uint8x16_t q6_qs_h = q6_qh[idx];
// Extract high 2 bits for upper nibble reconstruction
const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
const int8x16_t q6_l =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
}
const int scale_idx_l = half * 8 + sb;
const int scale_idx_h = half * 8 + sb + 4;
const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
}
}
} // for half
// Bias correction
acc[0] = vsubq_s32(acc[0], bias_lo);
acc[1] = vsubq_s32(acc[1], bias_hi);
// Apply superblock scale (no mins for q6_K)
// acc[g] has [c0, c1, c2, c3]
float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
@@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n,
q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
}
// TODO: Test other qh repack patterns to reduce loads
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base);
ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64);
ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base);
ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64);
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
// Adjust qh for subblocks 2 and 3 (shift right by 2)
if (sb > 1) {
@@ -3038,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntb() * 8 == 256) {
constexpr int q8_k_blocklen = 4;
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
svbool_t pg = svptrue_pat_b32(SV_VL8);
svuint32_t idx = svld1(pg, idx_arr);
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
acc_f32_01 = svdup_n_f32(0);
acc_f32_23 = svdup_n_f32(0);
acc_f32_45 = svdup_n_f32(0);
acc_f32_67 = svdup_n_f32(0);
for (int b = 0; b < nb; b++) {
// bsums pairs belongs to the same q8_k subblock
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
const int16x8_t bsums[4]{
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int32_t bsums_arr32[4][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
int16x8_t v16 = bsums[q8_row];
// low 4
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
// high 4
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
}
svint32_t sb_acc_0 = svdup_n_s32(0);
svint32_t sb_acc_2 = svdup_n_s32(0);
svint32_t acc_00 = svdup_n_s32(0);
svint32_t acc_11 = svdup_n_s32(0);
svint32_t acc_22 = svdup_n_s32(0);
svint32_t acc_33 = svdup_n_s32(0);
svint32_t acc_44 = svdup_n_s32(0);
svint32_t acc_55 = svdup_n_s32(0);
svint32_t acc_66 = svdup_n_s32(0);
svint32_t acc_77 = svdup_n_s32(0);
svint32_t bias_acc_00 = svdup_n_s32(0);
svint32_t bias_acc_22 = svdup_n_s32(0);
svint32_t bias_acc_44 = svdup_n_s32(0);
svint32_t bias_acc_66 = svdup_n_s32(0);
for (int sb = 0; sb < QK_K / 64; sb++) {
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
svint32_t q4sb_mins_0, q4sb_mins_1;
{
// 2-superblock I am working on
const int offset = sb * 24 + 0 * 12;
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
const int offset1 = sb * 24 + 12;
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
constexpr uint8_t scales_size = 12;
uint32_t sm[3];
memcpy(sm, scales_in, scales_size);
uint32_t sm1[3];
memcpy(sm1, scales_in1, scales_size);
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
/* reinterpret u32 → u8 */
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
/* widen u8 → u16->u32 (lower half only) */
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
uint32_t scales_u32_0 = sm[0] & kmask1;
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
uint32_t scales_u32_2 = sm1[0] & kmask1;
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
svuint32_t S01 = svdup_n_u32(scales_u32_0);
svuint32_t S23 = svdup_n_u32(scales_u32_1);
svuint32_t R01 = svdup_n_u32(scales_u32_2);
svuint32_t R23 = svdup_n_u32(scales_u32_3);
svint8_t S01_b = svreinterpret_s8_u32(S01);
svint8_t S23_b = svreinterpret_s8_u32(S23);
svint8_t R01_b = svreinterpret_s8_u32(R01);
svint8_t R23_b = svreinterpret_s8_u32(R23);
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
}
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
// Load 32-byte per row pair, 1 subblock each time
// predicate for activating higher lanes for 16 int8 elements
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
// predicate for activating lower lanes for 16 int8 elements
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
sb_acc_0 = svdup_n_s32(0);
sb_acc_2 = svdup_n_s32(0);
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
if(cp == 0) {
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
}
if(cp == 1) {
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
}
if(cp == 2) {
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
}
if(cp == 3) {
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
}
}
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
} // for sb
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
// Broadcast q8 scalar
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
q8_d = svdup_f32(q8_ptr[b].d[1]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
q8_d = svdup_f32(q8_ptr[b].d[2]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
q8_d = svdup_f32(q8_ptr[b].d[3]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
} // for b
// With the previous reorder, the tile is already in the correct memory layout.
// Predicate for exactly 4 lanes
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
if (i == 0 && j == 0) {
// acc_f32_0 → lower half of acc_f32_01
svst1_f32(pg4, s + offset, acc_f32_01);
} else if (i == 0 && j == 1) {
// acc_f32_1 → upper half of acc_f32_01
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
} else if (i == 1 && j == 0) {
// acc_f32_2
svst1_f32(pg4, s + offset, acc_f32_23);
} else if (i == 1 && j == 1) {
// acc_f32_3
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
} else if (i == 2 && j == 0) {
// acc_f32_4
svst1_f32(pg4, s + offset, acc_f32_45);
} else if (i == 2 && j == 1) {
// acc_f32_5
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
} else if (i == 3 && j == 0) {
// acc_f32_6
svst1_f32(pg4, s + offset, acc_f32_67);
} else if (i == 3 && j == 1) {
// acc_f32_7
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
}
}
}
} // for x
} // for y
return;
}
#endif // SVE compile-time end
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
@@ -3474,6 +3972,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n,
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q6_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int col_groups = ncols_interleaved / 4;
constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
const int8x16_t m32s = vdupq_n_s8(32);
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int i = 0; i < acc_size; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
float32x4_t sbd_scale_0123[q8_k_blocklen];
float32x4_t sbd_scale_4567[q8_k_blocklen];
sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
int32x4_t acc_s32[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_s32[i] = vdupq_n_s32(0);
}
int16_t q6_scales[8 * 16];
for (int i = 0; i < 16; i++) {
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
vst1q_s16(q6_scales + i * 8, scales);
}
for (int half = 0; half < 2; half++) {
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
for (int sb = 0; sb < QK_K / 64; sb++) {
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
// 4 rows * 16 elements per scale
// 4 reads of 16 bytes each
constexpr int reads_per_sb = 4;
int8x16_t q8_l[reads_per_sb];
int8x16_t q8_h[reads_per_sb];
for (int k = 0; k < reads_per_sb; k++) {
q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
}
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255;
uint8x16_t q6_ql_0123[reads_per_sb];
uint8x16_t q6_ql_4567[reads_per_sb];
uint8x16_t q6_qh_0123[reads_per_sb];
uint8x16_t q6_qh_4567[reads_per_sb];
for (int k = 0; k < reads_per_sb; k++) {
q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
}
if (sb > 1) {
for (int k = 0; k < reads_per_sb; k++) {
q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
}
}
for (int k = 0; k < reads_per_sb; k++) {
// q = (ql | qh) - 32
const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
const int8x16_t q6_0123_lo = vsubq_s8(
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
const int8x16_t q6_0123_hi = vsubq_s8(
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
const int8x16_t q6_4567_lo = vsubq_s8(
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
const int8x16_t q6_4567_hi = vsubq_s8(
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
}
// Scale and bias
const int scale_idx_l = half * 8 + sb;
const int scale_idx_h = half * 8 + sb + 4;
for (int g = 0; g < col_groups; g++) {
const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
const int acc_offset = g * q8_k_blocklen;
for (int row = 0; row < q8_k_blocklen; row++) {
const int idx = row * 2 + g;
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
}
}
}
}
// Finally we apply the superblock scales
for (int row = 0; row < q8_k_blocklen; row++) {
const int idx0 = 2 * row;
const int idx1 = 2 * row + 1;
const int32x4_t acc_0123 = acc_s32[idx0];
const int32x4_t acc_4567 = acc_s32[idx1];
acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
}
} // for b
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q6_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
+2 -6
View File
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
GGML_ASSERT(nb00 == sizeof(src0_t));
const auto [ir0, ir1] = get_thread_range(params, src0);
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
GGML_ASSERT(ggml_are_same_shape(src0, src1));
}
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
#ifdef GGML_USE_ACCELERATE
vDSP_fn_t vDSP_op = nullptr;
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
if (is_src1_contiguous) {
if (is_src1_contiguous_rows) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t nr0 = ne00 / ne10;
+2 -2
View File
@@ -6,8 +6,8 @@
#include "ggml-impl.h"
#include "simd-mappings.h"
#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#define GGML_FA_TILE_Q 64
#define GGML_FA_TILE_KV 64
#ifdef __cplusplus
+12 -4
View File
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
const int64_t DV = node->src[2]->ne[0];
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/*.use_ref =*/ cplan->use_ref,
};
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
#else
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
#else
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif
ggml_barrier(state->threadpool);
-333
View File
@@ -1,333 +0,0 @@
#pragma once
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
template <typename TA>
class tinyBLAS_Q0_PPC {
public:
tinyBLAS_Q0_PPC(int64_t k,
const TA *A, int64_t lda,
const block_q8_0 *B, int64_t ldb,
float *C, int64_t ldc,
int ith, int nth);
void matmul(int64_t m, int64_t n);
void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
vec_t A_pack[mc*kc*2];
vec_t B_pack[nc*kc*2];
int comparray[mc*kc];
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
if constexpr(is_Ablock_q4) {
packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
} else {
packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
}
packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
}
}
}
private:
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
}
}
}
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
*c_ptr += *((float*)&fin_res[idx+I]+J);
}
}
}
template<typename ArrayType>
inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
vector signed int vec_C[4];
vector float CA[4] = {0};
vector float res[4] = {0};
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int i = 0; i < 4; i++) {
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
}
}
inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
const vector signed char v8 = vec_splats((signed char)0x8);
vector signed int vsum = {0};
vector signed int vsum2 = {0};
c[0] = vec_and(c[1], lowMask);
c[1] = vec_sr(c[1], v4);
c[0] = vec_sub(c[0], v8);
c[1] = vec_sub(c[1], v8);
vsum = vec_sum4s(c[0], vsum);
vsum2 = vec_sum4s(c[1], vsum2);
vsum = vec_add(vsum, vsum2);
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
}
template <typename V1, typename V2>
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
V2 t1, t2, t3, t4, t5, t6, t7, t8;
vector unsigned char xor_vector;
uint8_t flip_vec = 0x80;
xor_vector = vec_splats(flip_vec);
t1 = vec_perm(s1, s2, swiz1);
t2 = vec_perm(s1, s2, swiz2);
t3 = vec_perm(s3, s4, swiz1);
t4 = vec_perm(s3, s4, swiz2);
t5 = vec_perm(t1, t3, swiz3);
t6 = vec_perm(t1, t3, swiz4);
t7 = vec_perm(t2, t4, swiz3);
t8 = vec_perm(t2, t4, swiz4);
if (flip == true) {
t5 = vec_xor(t5, xor_vector);
t6 = vec_xor(t6, xor_vector);
t7 = vec_xor(t7, xor_vector);
t8 = vec_xor(t8, xor_vector);
}
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset+16);
vec_xst(t7, 0, vecOffset+32);
vec_xst(t8, 0, vecOffset+48);
}
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii,jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii,jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii,jj);
} else {
assert(false && "RN/RM values not supported");
}
}
template<int size>
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
template<typename VA, typename VB>
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
void KERNEL_4x8(int64_t ii, int64_t jj);
void KERNEL_8x4(int64_t ii, int64_t jj);
void KERNEL_8x8(int64_t ii, int64_t jj);
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
template <int RM, int RN>
void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
for (int I = 0; I<8; I++) {
float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
*((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
}
}
}
inline void process_q8_elements(const int8_t *qs, int *ca) {
vector signed char c1 = vec_xl(0, qs);
vector signed char c2 = vec_xl(16, qs);
vector signed int vsum1 = {0};
vector signed int vsum2 = {0};
vsum1 = vec_sum4s(c1, vsum1);
vsum2 = vec_sum4s(c2, vsum2);
vector signed int vsum = vec_add(vsum1, vsum2);
*ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
}
template<typename VA, typename VB>
void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
int64_t i, j;
block_q8_0 *aoffset = NULL;
VA *vecOffset = NULL;
block_q8_0* aoffsets[8];
__vector_pair arr[8];
VB c[8][2] = {0};
VB c1[8] = {0}; VB c2[8] = {0};
aoffset = const_cast<block_q8_0*>(a);
vecOffset = vec;
j = (rows >> 3);
int index = 0;
if (j > 0) {
do {
for (int it = 0; it < 8; it++)
aoffsets[it] = aoffset + it*lda;
aoffset += 8 * lda;
for (int blk = 0; blk < kc; blk++) {
for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
if (comparray){
process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
}
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
vecOffset += 256;
}
j--;
index += 8*kc;
} while(j > 0);
}
}
void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
int64_t i, j;
TA *aoffset = NULL;
int8_t *vecOffset = NULL;
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
aoffset = const_cast<TA*>(a);
vecOffset = vec;
int index = 0;
j = (rows >> 3);
if (j > 0) {
do {
aoffset1 = aoffset;
aoffset2 = aoffset1 + lda;
aoffset3 = aoffset2 + lda;
aoffset4 = aoffset3 + lda;
aoffset5 = aoffset4 + lda;
aoffset6 = aoffset5 + lda;
aoffset7 = aoffset6 + lda;
aoffset8 = aoffset7 + lda;
aoffset += 8 * lda;
for (int blk = 0; blk < kc; blk++) {
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
process_q4_elements(c1, &comparray[index + 8*blk+0]);
process_q4_elements(c2, &comparray[index + 8*blk+1]);
process_q4_elements(c3, &comparray[index + 8*blk+2]);
process_q4_elements(c4, &comparray[index + 8*blk+3]);
process_q4_elements(c5, &comparray[index + 8*blk+4]);
process_q4_elements(c6, &comparray[index + 8*blk+5]);
process_q4_elements(c7, &comparray[index + 8*blk+6]);
process_q4_elements(c8, &comparray[index + 8*blk+7]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
vecOffset += 256;
}
j--;
index += 8*kc;
} while (j > 0);
}
}
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
acc_t acc[8];
for (int i = 0; i < mc ; i += 8) {
for (int j = 0; j < nc; j += 8) {
vector float fin_res[16] = {0};
vector float vs[16] = {0};
for (int64_t kk = 0; kk < kc; kk+=2) {
for (int x = 0; x < 8; x++) {
__builtin_mma_xxsetaccz(&acc[x]);
}
int A_block_idx = (i/8)*(16*kc) + kk*16;
int B_block_idx = (j/8)*(16*kc)+ kk*16;
vec_t *A_block = &vec_A[A_block_idx];
vec_t *B_block = &vec_B[B_block_idx];
for (int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
}
compute_scale(ii+i, jj+j, l+kk, vs);
int c_index = (i/8)*(8*kc)+ kk*8;
int* c_block = &comparray[c_index];
compute(&acc[0], 0, 0, c_block, vs, fin_res);
compute(&acc[1], 4, 4, c_block, vs, fin_res);
compute(&acc[2], 0, 8, c_block, vs, fin_res);
compute(&acc[3], 4, 12, c_block, vs, fin_res);
A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
A_block = &vec_A[A_block_idx];
B_block = &vec_B[B_block_idx];
for (int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
__builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
}
compute_scale(ii+i, jj+j, l+kk+1, vs);
c_index = (i/8)*(8*kc)+ (kk+1)*8;
c_block = &comparray[c_index];
compute(&acc[4], 0, 0, c_block, vs, fin_res);
compute(&acc[5], 4, 4, c_block, vs, fin_res);
compute(&acc[6], 0, 8, c_block, vs, fin_res);
compute(&acc[7], 4, 12, c_block, vs, fin_res);
}
if (l == 0) {
save_res(ii+i, jj+j, 0, fin_res);
save_res(ii+i+4, jj+j, 4, fin_res);
save_res(ii+i, jj+j+4, 8, fin_res);
save_res(ii+i+4, jj+j+4, 12, fin_res);
} else {
add_save_res(ii+i, jj+j, 0, fin_res);
add_save_res(ii+i+4, jj+j, 4, fin_res);
add_save_res(ii+i, jj+j+4, 8, fin_res);
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
}
}
}
}
const TA *const A;
const block_q8_0 *const B;
float *C;
const int64_t k;
int64_t kc;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
+507 -155
View File
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
#endif
#if defined(__MMA__)
#include "sgemm-ppc.h"
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
}
}
SAVE_ACC(&acc_0, ii, jj);
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
const int nth;
};
template <typename TA>
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
const TA *A, int64_t lda,
const block_q8_0 *B, int64_t ldb,
float *C, int64_t ldc,
int ith, int nth)
template <typename TA>
class tinyBLAS_Q0_PPC {
public:
tinyBLAS_Q0_PPC(int64_t k,
const TA * A, int64_t lda,
const block_q8_0 * B, int64_t ldb,
float * C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
kc = 64;
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
int mc = 64; int nc = 64;
if (n % 8 == 0 && n < nc) {
nc = n;
mc = 32 ;
kc = 32;
void matmul(int64_t m, int64_t n) {
const int64_t mc = 64;
const int64_t kc = 64;
int64_t nc = 64;
int64_t n_aligned = 0;
if (n % 64 == 0) {
n_aligned = n;
} else if (n == 4) {
n_aligned = 4;
} else if (n < 64) {
n_aligned = (n / 8) * 8;
} else {
n_aligned = (n / 64) * 64;
}
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
if (is_aligned) {
this->matmul_tiled_q0(m, n, mc, nc, kc);
if (n_aligned > 0) {
if (n_aligned % 64 == 0) nc = 64;
else if (n_aligned == n) nc = n;
else if (n_aligned % 32 == 0) nc = 32;
else if (n_aligned % 24 == 0) nc = 24;
else if (n_aligned % 16 == 0) nc = 16;
else nc = 8;
}
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
if (can_use_tiled) {
matmul_tiled(m, n_aligned, mc, nc, kc);
if (n > n_aligned) {
mnpack(0, m, n_aligned, n);
}
} else {
mnpack(0, m, 0, n);
}
}
template<typename TA>
template<int size>
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
private:
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
}
}
}
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
}
}
}
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
*c_ptr += *((float *)&vec_C[I] + J);
}
}
}
template<typename ArrayType>
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
vector signed int vec_C[4];
vector float CA[4] = {0};
vector float res[4] = {0};
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int i = 0; i < 4; i++) {
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
}
}
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
const vector signed char v8 = vec_splats((signed char)0x8);
vector signed int vsum = {0};
vector signed int vsum2 = {0};
c[0] = vec_and(c[1], lowMask);
c[1] = vec_sr(c[1], v4);
c[0] = vec_sub(c[0], v8);
c[1] = vec_sub(c[1], v8);
vsum = vec_sum4s(c[0], vsum);
vsum2 = vec_sum4s(c[1], vsum2);
vsum = vec_add(vsum, vsum2);
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
}
template <typename V1, typename V2>
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
V2 t1, t2, t3, t4, t5, t6, t7, t8;
vector unsigned char xor_vector;
uint8_t flip_vec = 0x80;
xor_vector = vec_splats(flip_vec);
t1 = vec_perm(s1, s2, swiz1);
t2 = vec_perm(s1, s2, swiz2);
t3 = vec_perm(s3, s4, swiz1);
t4 = vec_perm(s3, s4, swiz2);
t5 = vec_perm(t1, t3, swiz3);
t6 = vec_perm(t1, t3, swiz4);
t7 = vec_perm(t2, t4, swiz3);
t8 = vec_perm(t2, t4, swiz4);
if (flip == true) {
t5 = vec_xor(t5, xor_vector);
t6 = vec_xor(t6, xor_vector);
t7 = vec_xor(t7, xor_vector);
t8 = vec_xor(t8, xor_vector);
}
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 16);
vec_xst(t7, 0, vecOffset + 32);
vec_xst(t8, 0, vecOffset + 48);
}
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
const vector signed char lowMask = vec_splats((signed char)0x0F);
const vector signed char v8 = vec_splats((signed char)0x08);
const vector unsigned char v4 = vec_splats((unsigned char)4);
lo = vec_and(packed, lowMask);
hi = vec_sr(packed, v4);
lo = vec_sub(lo, v8);
hi = vec_sub(hi, v8);
}
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
vec_t t[8], s[8];
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
for (int i = 0; i < 4; i += 2) {
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
}
for (int i = 4; i < 8; i += 2) {
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
}
s[0] = vec_perm(t[0], t[2], swiz3);
s[1] = vec_perm(t[0], t[2], swiz4);
s[2] = vec_perm(t[1], t[3], swiz3);
s[3] = vec_perm(t[1], t[3], swiz4);
s[4] = vec_perm(t[4], t[6], swiz3);
s[5] = vec_perm(t[4], t[6], swiz4);
s[6] = vec_perm(t[5], t[7], swiz3);
s[7] = vec_perm(t[5], t[7], swiz4);
for (int i = 0; i < 8; ++i) {
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
}
}
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
vector signed short i16_hi = vec_unpackh(raw);
vector signed short i16_lo = vec_unpackl(raw);
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
}
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
unsigned char * vecOffset = vec;
for (int i = 0; i < rows; i += 8) {
const block_q4_0 * rows_base[8];
for (int r = 0; r < 8; r++) {
rows_base[r] = a + (i + r) * lda;
}
for (int blk = 0; blk < blocks; blk++) {
vector unsigned short hp_res[8][4];
for (int r = 0; r < 8; r++) {
const block_q4_0 * current_blk = rows_base[r] + blk;
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
vector signed char c1, c2;
unpack_q4_to_q8(v_qs, c1, c2);
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
}
for (int c = 0; c < 4; c++) {
vector unsigned char c_arr[8];
for (int r = 0; r < 8; r++) {
c_arr[r] = (vector unsigned char)hp_res[r][c];
}
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
vecOffset += 128;
}
}
}
}
template <int chunk_size>
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
unsigned char * vecOffset = vec;
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
for (int i = 0; i < rows; i += chunk_size) {
const block_q8_0 * rows_base[chunk_size];
for (int r = 0; r < chunk_size; r++) {
rows_base[r] = a + (i + r) * lda;
}
for (int blk = 0; blk < blocks; blk++) {
vector unsigned short hp_res[chunk_size][4];
for (int r = 0; r < chunk_size; r++) {
const block_q8_0 * b = rows_base[r] + blk;
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
vector signed char c[2];
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
__builtin_vsx_disassemble_pair(c, & pair);
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
}
for (int col = 0; col < 4; col++) {
if constexpr (chunk_size == 8) {
vec_t t[8];
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
vecOffset += 128;
} else {
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
vecOffset += 64;
}
}
}
}
}
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
if (rows == 4) {
pack_q8_block<4>(a, lda, rows, blocks, vec);
} else {
pack_q8_block<8>(a, lda, rows, blocks, vec);
}
}
template<int size>
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
int64_t i, j;
TA *aoffset = NULL;
int8_t *vecOffset = NULL;
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
TA * aoffset = NULL;
int8_t * vecOffset = NULL;
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
aoffset = const_cast<TA*>(a);
aoffset = const_cast<TA *>(a);
vecOffset = vec;
j = (rows >> 3);
if (j > 0) {
@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
process_q4_elements(c1, &comparray[0]);
process_q4_elements(c2, &comparray[1]);
process_q4_elements(c3, &comparray[2]);
process_q4_elements(c4, &comparray[3]);
process_q4_elements(c5, &comparray[4]);
process_q4_elements(c6, &comparray[5]);
process_q4_elements(c7, &comparray[6]);
process_q4_elements(c8, &comparray[7]);
process_q4_elements(c1, & comparray[0]);
process_q4_elements(c2, & comparray[1]);
process_q4_elements(c3, & comparray[2]);
process_q4_elements(c4, & comparray[3]);
process_q4_elements(c5, & comparray[4]);
process_q4_elements(c6, & comparray[5]);
process_q4_elements(c7, & comparray[6]);
process_q4_elements(c8, & comparray[7]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
process_q4_elements(c1, &comparray[0]);
process_q4_elements(c2, &comparray[1]);
process_q4_elements(c3, &comparray[2]);
process_q4_elements(c4, &comparray[3]);
process_q4_elements(c1, & comparray[0]);
process_q4_elements(c2, & comparray[1]);
process_q4_elements(c3, & comparray[2]);
process_q4_elements(c4, & comparray[3]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
break;
}
process_q4_elements(c1, &comparray[0]);
process_q4_elements(c2, &comparray[1]);
process_q4_elements(c3, &comparray[2]);
process_q4_elements(c4, &comparray[3]);
process_q4_elements(c1, & comparray[0]);
process_q4_elements(c2, & comparray[1]);
process_q4_elements(c3, & comparray[2]);
process_q4_elements(c4, & comparray[3]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
}
}
template<typename TA>
template<typename VA, typename VB>
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
int64_t i, j;
block_q8_0 *aoffset = NULL;
VA *vecOffset = NULL;
block_q8_0* aoffsets[8];
block_q8_0 * aoffset = NULL;
VA * vecOffset = NULL;
block_q8_0 * aoffsets[8];
__vector_pair arr[8];
VB c[8][2] = {0};
VB c1[8] = {0}; VB c2[8] = {0};
aoffset = const_cast<block_q8_0*>(a);
aoffset = const_cast<block_q8_0 *>(a);
vecOffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
aoffsets[0] = aoffset;
for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffsets[it] = aoffsets[it - 1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
for (int it = 0; it < 8; it++)
aoffsets[it] += lda;
vecOffset += 256;
@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
if (i > 0) {
do {
for (int it = 0; it < 4; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
for (int it = 0; it < 4; it++) {
aoffsets[it] += lda;
}
@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
if (rows & 3) {
aoffsets[0] = aoffset;
for (int it = 1; it < 3; it++ )
aoffsets[it] = aoffsets[it-1] + lda;
aoffsets[it] = aoffsets[it - 1] + lda;
i = (cols >> 3);
if (i > 0) {
do {
switch(rows) {
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
c1[2] = c[2][0]; c2[2] = c[2][1];
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
c1[1] = c[1][0]; c2[1] = c[1][1];
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
c1[0] = c[0][0]; c2[0] = c[0][1];
break;
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
for (int it = 0; it < 3; it++)
aoffsets[it] += lda;
vecOffset += 128;
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
}
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int m_rem = MIN(m - m0, 16);
int n_rem = MIN(n - n0, 16);
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
void KERNEL_4x8(int64_t ii, int64_t jj) {
vec_t vec_A[8], vec_B[16] = {0};
acc_t acc_0, acc_1;
std::array<int, 4> comparray {};
@@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC {
vector float vs[8] = {0};
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
for (int l = 0; l < k; l++) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(& acc_0);
__builtin_mma_xxsetaccz(& acc_1);
if (std::is_same_v<TA, block_q4_0>) {
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
for(int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
}
for (int I = 0; I<4; I++) {
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
}
}
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < 4; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
aoffset += lda;
}
}
compute(&acc_0, 0, 0, comparray, vs, fin_res);
compute(&acc_1, 0, 4, comparray, vs, fin_res);
compute(& acc_0, 0, 0, comparray, vs, fin_res);
compute(& acc_1, 0, 4, comparray, vs, fin_res);
}
save_res(ii, jj, 0, fin_res);
save_res(ii, jj+4, 4, fin_res);
save_res(ii, jj + 4, 4, fin_res);
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
void KERNEL_8x4(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[8] = {0};
acc_t acc_0, acc_1;
std::array<int, 8> comparray {};
@@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC {
vector float vs[8] = {0};
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
for (int l = 0; l < k; l++) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(& acc_0);
__builtin_mma_xxsetaccz(& acc_1);
if (std::is_same_v<TA, block_q4_0>) {
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
for(int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
}
for (int I = 0; I<8; I++) {
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
for (int I = 0; I < 8; I++) {
for (int J = 0; J < 4; J++) {
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
}
}
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < 8; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
aoffset += lda;
}
}
compute(&acc_0, 0, 0, comparray, vs, fin_res);
compute(&acc_1, 4, 4, comparray, vs, fin_res);
compute(& acc_0, 0, 0, comparray, vs, fin_res);
compute(& acc_1, 4, 4, comparray, vs, fin_res);
}
save_res(ii, jj, 0, fin_res);
save_res(ii+4, jj, 4, fin_res);
save_res(ii + 4, jj, 4, fin_res);
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
void KERNEL_8x8(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[16] = {0};
acc_t acc_0, acc_1, acc_2, acc_3;
acc_t acc_4, acc_5, acc_6, acc_7;
@@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC {
vector float vs[16] = {0};
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
for (int l = 0; l < k; l++) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
__builtin_mma_xxsetaccz(& acc_0);
__builtin_mma_xxsetaccz(& acc_1);
__builtin_mma_xxsetaccz(& acc_2);
__builtin_mma_xxsetaccz(& acc_3);
if (std::is_same_v<TA, block_q4_0>) {
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
for(int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
}
for (int I = 0; I<8; I++) {
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
for (int I = 0; I < 8 ; I++) {
for (int J = 0; J < 4; J++) {
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
}
}
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < 8; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
aoffset += lda;
}
}
compute(&acc_0, 0, 0, comparray, vs, fin_res);
compute(&acc_1, 4, 4, comparray, vs, fin_res);
compute(&acc_2, 0, 8, comparray, vs, fin_res);
compute(&acc_3, 4, 12, comparray, vs, fin_res);
compute(& acc_0, 0, 0, comparray, vs, fin_res);
compute(& acc_1, 4, 4, comparray, vs, fin_res);
compute(& acc_2, 0, 8, comparray, vs, fin_res);
compute(& acc_3, 4, 12, comparray, vs, fin_res);
}
save_res(ii, jj, 0, fin_res);
save_res(ii+4, jj, 4, fin_res);
save_res(ii, jj+4, 8, fin_res);
save_res(ii+4, jj+4, 12, fin_res);
save_res(ii + 4, jj, 4, fin_res);
save_res(ii, jj + 4, 8, fin_res);
save_res(ii + 4, jj + 4, 12, fin_res);
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
acc_t acc[8];
for (int i = 0; i < mc ; i += 16) {
for (int j = 0; j < nc; j += 8) {
int A0_base = (i / 16) * (2 * 32 * kc);
int B0_base = (j / 8) * (32 * kc);
for (int x = 0; x < 8; x++) {
__builtin_mma_xxsetaccz(&acc[x]);
}
for (int64_t kk = 0; kk < kc; kk++) {
int A0_block_idx = A0_base + kk * 32;
int B0_block_idx = B0_base + kk * 32;
int A1_block_idx = A0_block_idx + 32 * kc;
int B1_block_idx = B0_block_idx + 32 * kc;
vec_t * A0_block = & vec_A[A0_block_idx];
vec_t * B0_block = & vec_B[B0_block_idx];
vec_t * A1_block = & vec_A[A1_block_idx];
for (int it = 0; it < 4; it++) {
for (int x = 0; x < 4; x++) {
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
}
}
}
if (l == 0) {
save_acc(& acc[0], ii + i, jj + j);
save_acc(& acc[1], ii + i, jj + j + 4);
save_acc(& acc[2], ii + i + 4, jj + j);
save_acc(& acc[3], ii + i + 4, jj + j + 4);
save_acc(& acc[4], ii + i + 8, jj + j);
save_acc(& acc[5], ii + i + 8, jj + j + 4);
save_acc(& acc[6], ii + i + 12, jj + j);
save_acc(& acc[7], ii + i + 12, jj + j + 4);
} else {
add_save_acc(& acc[0], ii + i, jj + j);
add_save_acc(& acc[1], ii + i, jj + j + 4);
add_save_acc(& acc[2], ii + i + 4, jj + j);
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
add_save_acc(& acc[4], ii + i + 8, jj + j);
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
add_save_acc(& acc[6], ii + i + 12, jj + j);
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
}
}
}
}
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
vec_t A_pack[mc * kc * 4];
vec_t B_pack[nc * kc * 4];
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
if constexpr(is_Ablock_q4) {
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
} else {
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
}
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
}
}
}
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC {
vector float fin_res[4] = {0};
vector float vs[4] = {0};
vector float CA[4] = {0};
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
for (int l = 0; l < k; l++) {
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(&acc_0);
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(& acc_0);
if (isAblock_q4) {
packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x+=4) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
for (int x = 0; x < 8; x += 4) {
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
}
for (int I = 0; I<RM; I++) {
for (int J = 0; J<RN; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
__builtin_mma_disassemble_acc(vec_C, & acc_0);
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < RM; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
}
}
template<typename TA>
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii,jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii,jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii,jj);
} else {
assert(false && "RN/RM values not supported");
}
}
template <int RM, int RN>
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
this->kernel<RM, RN>(ii, jj);
kernel<RM, RN>(ii, jj);
}
}
template class tinyBLAS_Q0_PPC<block_q4_0>;
template class tinyBLAS_Q0_PPC<block_q8_0>;
const TA * const A;
const block_q8_0 * const B;
float * C;
const int64_t k;
int64_t kc;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
class tinyBLAS_PPC {
public:
+173 -94
View File
@@ -3,6 +3,7 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"
@@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -8325,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
@@ -8360,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
@@ -8388,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
float * V32 = VKQ32 + Q_TILE_SZ * DV;
float * K_f32 = V32 + KV_TILE_SZ * DV;
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8412,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
{
float * Q_f32 = (float *)Q_q;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
}
}
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
// skip the tile entirely if all the masks are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
for (int tk = 0; tk < kv_tile; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
// Pad remaining mask entries with -inf
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
if (can_skip) {
@@ -8441,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
for (int tk = 0; tk < kv_tile; tk++) {
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
if (kv_type == GGML_TYPE_F16) {
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
}
} else {
const float * k_f32_src = (const float *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
}
}
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
// Set padded KQ entries to -inf so softmax gives them zero weight
if (kv_tile < KV_TILE_SZ) {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
}
@@ -8487,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}
// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
// V accumulation: VKQ32 += softmax(KQ) * V
// Pack V tile to contiguous F32, zero-padded
for (int tk = 0; tk < kv_tile; tk++) {
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
if (kv_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
} else {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) {
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
}
}
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
}
// sinks (apply only to valid rows in the tile)
@@ -8730,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool use_tiled = !use_ref &&
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD
use_tiled &= (DV % GGML_F32_EPR == 0);
#endif
int current_chunk = ith;
while (current_chunk < nchunk) {
+232 -198
View File
@@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
}
template <int M, int N>
static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int blocklen = M;
constexpr int ncols_interleaved = N;
const int qk = QK_K;
const int nb = n / qk;
const int blocks_per_half = 64 / blocklen;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0f;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
for (int j = 0; j < ncols_interleaved; j++) {
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / blocklen;
const int qh_pos_l = qh_idx_l % blocklen;
const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / blocklen;
const int qh_pos_h = qh_idx_h % blocklen;
const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t a_l = a_ptr[l].qs[base_l + i];
const int8_t a_h = a_ptr[l].qs[base_h + i];
sumi_l += q_l * a_l;
sumi_h += q_h * a_h;
}
sumf[j] +=
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
}
template <int M, int N>
static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int blocklen = M;
constexpr int ncols_interleaved = N;
const int qk = QK_K;
const int nb = n / qk;
const int blocks_per_half = 64 / blocklen;
const int q8_half_stride = 512;
const int q8_low_high_step = 256;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
float sumf[4][8];
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0f;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / blocklen;
const int qh_pos_l = qh_idx_l % blocklen;
const int qh_offset_l =
qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / blocklen;
const int qh_pos_h = qh_idx_h % blocklen;
const int qh_offset_h =
qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
sumi_l += q_l * q8_l;
sumi_h += q_h * q8_h;
}
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
}
extern "C" {
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
}
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0f;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < 16; k++) {
// k = 0.. 7 weights 0-63 low, 64-127 high
// k = 8..15 weights 128-191 low, 192-255 high
const int base_l = (k / 8) * 128 + (k % 8) * 8;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
// qh_half: offset to the correct 32-byte half (0 or 32)
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
for (int j = 0; j < ncols_interleaved; j++) {
// Interleaved scales
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * 64 + j * 8 + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
// qh indexing with 8-byte interleaving (like q5_K)
const int qh_byte_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_byte_l / 8;
const int qh_pos_l = qh_byte_l % 8;
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_byte_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_byte_h / 8;
const int qh_pos_h = qh_byte_h % 8;
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t a_l = a_ptr[l].qs[base_l + i];
const int8_t a_h = a_ptr[l].qs[base_h + i];
sumi_l += q_l * a_l;
sumi_h += q_h * a_h;
}
sumf[j] +=
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
}
}
void ggml_gemm_q6_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
float sumf[4][8];
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0f;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < 16; k++) {
// k = 0.. 7 weights 0-63 low, 64-127 high
// k = 8..15 weights 128-191 low, 192-255 high
const int base_l = (k / 8) * 128 + (k % 8) * 8;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
// qh_half: offset to the correct 32-byte half (0 or 32)
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
// Activation base indices for q8_Kx4 interleaved format
// Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
const int q8_base = (k / 8) * 512 + (k % 8) * 32;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
// Interleaved scales
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * 64 + j * 8 + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / 8;
const int qh_pos_l = qh_idx_l % 8;
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / 8;
const int qh_pos_h = qh_idx_h % 8;
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
sumi_l += q_l * q8_l;
sumi_h += q_h * q8_h;
}
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -1901,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
// buffer large enough for the max interleave block size (8 bytes)
uint64_t elems;
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
}
// The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
@@ -2097,18 +2113,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
}
const int end_ls = QK_K * 4 / blck_size_interleave;
// Interleave Q6_K quants by taking 8 bytes at a time
// Interleave Q6_K quants by taking blck_size_interleave bytes at a time
for (int i = 0; i < end_ls; ++i) {
int src_id = i % n_blocks;
int src_offset = (i / n_blocks) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elem_ls;
memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
}
// Interleave high bits using same 8-byte pattern as low bits
// Interleave high bits using same chunk size as low bits
const int end_hs = end_ls / 2;
for (int i = 0; i < end_hs; ++i) {
int src_id = i % n_blocks;
@@ -2116,8 +2132,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
int dst_offset = i * blck_size_interleave;
uint64_t elem_hs;
memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
}
// The below logic is designed so as to unpack and rearrange scales in Q6_K
@@ -2262,7 +2278,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
GGML_ASSERT(interleave_block == 8);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
@@ -2511,6 +2527,10 @@ template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
}
@@ -2575,6 +2595,10 @@ template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@@ -2634,6 +2658,10 @@ template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@@ -3043,6 +3071,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
// instance for Q6_K
static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
// instance for Q2
@@ -3107,6 +3136,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q6_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 8 == 0) {
return &q6_K_8x4_q8_K;
}
}
} else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {
+4
View File
@@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+136
View File
@@ -0,0 +1,136 @@
#pragma once
// Computes C[M x N] += A[M x K] * B[K x N]
#include "simd-mappings.h"
// TODO: add support for sizeless vector types
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
// TODO: untested on avx512
// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
#elif defined(__AVX2__) || defined(__AVX__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
#else
static constexpr int GEMM_RM = 2;
static constexpr int GEMM_RN = 2;
#endif
template <int RM, int RN>
static inline void simd_gemm_ukernel(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
GGML_F32_VEC acc[RM][RN];
for (int64_t i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
}
}
for (int64_t kk = 0; kk < K; kk++) {
GGML_F32_VEC Bv[RN];
for (int r = 0; r < RN; r++) {
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
}
for (int64_t i = 0; i < RM; i++) {
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
}
}
}
for (int64_t i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
}
}
}
// C[M x N] += A[M x K] * B[K x N]
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
int64_t ii = 0;
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
}
for (; jj < N; jj++) {
for (int64_t i = 0; i < GEMM_RM; i++) {
float a = C[i * N + jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[i + kk] * B[kk * N + jj];
}
C[i * N + jj] = a;
}
}
A += GEMM_RM * K;
C += GEMM_RM * N;
}
// Tail rows: one at a time
for (; ii < M; ii++) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
}
for (; jj < N; jj++) {
float a = C[jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[kk] * B[kk * N + jj];
}
C[jj] = a;
}
A += K;
C += N;
}
}
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#else // scalar path
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
for (int64_t i = 0; i < M; i++) {
for (int64_t j = 0; j < N; j++) {
float sum = C[i * N + j];
for (int64_t kk = 0; kk < K; kk++) {
sum += A[i * K + kk] * B[kk * N + j];
}
C[i * N + j] = sum;
}
}
}
#endif // GGML_SIMD
+26
View File
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
float32x4_t tmp = x[0] + vec_reve(x[0]); \
res = tmp[0] + tmp[1]; \
}
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
{ \
float32x4_t v = vec_add(vec_add(s0, s1), \
vec_add(s2, s3)); \
v = vec_add(v, vec_sld(v, v, 8)); \
v = vec_add(v, vec_sld(v, v, 4)); \
res += (ggml_float)vec_extract(v, 0); \
}
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
// BF16 s390x
#define GGML_BF16_STEP 16
#define GGML_BF16_EPR 8
#define GGML_BF16x8 __vector unsigned short
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
#define GGML_BF16_VEC GGML_BF16x8
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_FMA_LO(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
#define GGML_BF16_FMA_HI(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
#elif defined(__riscv_v_intrinsic)
// compatible with vlen >= 128
+1 -1
View File
@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS
+1 -2
View File
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
#endif
#if defined(__POWER9_VECTOR__)
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
const int np = (n & ~(GGML_BF16_STEP - 1));
if (np > 0) {
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
+1 -1
View File
@@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND)
FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG v3.2.0-rc2
GIT_TAG v3.2.0
GIT_SHALLOW TRUE
)
+32 -30
View File
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
//size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
}
}
+38 -24
View File
@@ -7,7 +7,8 @@
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
template <bool need_check>
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <typename src_t, typename dst_t>
+3 -1
View File
@@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
// On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
// However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
if constexpr (DV == 512) {
+26 -5
View File
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
#else
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
#else
const half * K_h_f16 = K_h;
const half * V_h_f16 = V_h;
half * KQ_f16 = KQ;
half * VKQ_f16 = VKQ;
#endif
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
}
}
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
KQ_f16 + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, wmma::mem_col_major);
}
+21 -35
View File
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
if (ne2 <= 4) {
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
cudaStream_t stream = ctx.stream();
GGML_ASSERT(nb12 % nb11 == 0);
@@ -2865,14 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2887,30 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
@@ -3640,11 +3619,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_add_node;
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;
continue;
@@ -4542,6 +4523,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
// TODO: should become:
//return ggml_is_contiguous_rows(op->src[0]);
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -4820,8 +4803,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
// TODO: extend support like so:
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:
+21 -16
View File
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR2_XXS; ++l) {
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR2_XS; ++l) {
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR3_XXS; ++l) {
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+1
View File
@@ -1,6 +1,7 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
+31 -17
View File
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
#endif
}
static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
// v is a 7 bit int, with the 8th sign being encodable as popcnt
// with xor we can "correct" the bit instead of having to mask
const uint32_t p = __popc(v) & 1;
const uint32_t s = v ^ p << 7;
// broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
return s * 0x01010101;
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
int sumi = 0;
#pragma unroll
for (int k0 = 0; k0 < 8; k0 += 2) {
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
}
const int ls = aux32 >> 28;
sumi = (ls*sumi + sumi/2)/4;
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
return d * sumi;
}
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
int sumi1 = 0;
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+106 -7
View File
@@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
@@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
return true;
}
static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
return true;
}
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * dst = op; // indices
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (dst->type != GGML_TYPE_I32) {
return false;
}
if (src0->ne[0] > (16*1024)) {
// reject tensors with huge rows for now
return false;
}
return true;
}
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];
@@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
case GGML_OP_SUB:
req->op = HTP_OP_SUB;
break;
case GGML_OP_DIV:
req->op = HTP_OP_DIV;
break;
default:
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
break;
@@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
return n_bufs;
}
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_ARGSORT;
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
@@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
supported = true;
break;
case GGML_OP_SQR:
req->op = HTP_OP_SQR;
supported = true;
break;
case GGML_OP_SQRT:
req->op = HTP_OP_SQRT;
supported = true;
break;
case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU;
@@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
req->op = HTP_OP_GLU_SWIGLU_OAI;
supported = true;
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
req->op = HTP_OP_GLU_GEGLU;
supported = true;
}
break;
@@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
return n_bufs;
}
static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_SUM_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_ROPE;
@@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
break;
case GGML_OP_ADD_ID:
@@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_SUM_ROWS:
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
break;
case GGML_OP_UNARY:
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
@@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
break;
case GGML_OP_GLU:
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
}
break;
@@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break;
case GGML_OP_ARGSORT:
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
@@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
supp = ggml_hexagon_supported_binary(sess, op);
break;
@@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_SUM_ROWS:
supp = ggml_hexagon_supported_sum_rows(sess, op);
break;
case GGML_OP_SOFT_MAX:
supp = ggml_hexagon_supported_softmax(sess, op);
break;
@@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_GLU:
{
const auto glu_op = ggml_get_glu_op(op);
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
supp = ggml_hexagon_supported_activations(sess, op);
}
break;
@@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_cpy(sess, op);
break;
case GGML_OP_ARGSORT:
supp = ggml_hexagon_supported_argsort(sess, op);
break;
default:
break;
}
+3
View File
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include_directories(
${HEXAGON_SDK_ROOT}/incs
${HEXAGON_SDK_ROOT}/incs/stddef
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}
@@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
softmax-ops.c
act-ops.c
rope-ops.c
@@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
argsort-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
+150 -2
View File
@@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
// geglu tanh implementation
// geglu(x, g) = gelu(x) * g
// gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
}
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static int execute_op_activations_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
act_op_func = unary_gelu_f32;
op_type = "gelu-f32";
break;
case HTP_OP_GLU_GEGLU:
act_op_func = glu_geglu_f32;
op_type = "geglu-f32";
break;
default:
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;
+281
View File
@@ -0,0 +1,281 @@
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <HAP_farf.h>
#include <HAP_perf.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"
#include "hvx-utils.h"
#include "hex-dma.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
struct htp_argsort_context {
struct htp_ops_context * octx;
uint32_t nrows_per_thread;
};
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
{
const HVX_Vector one = Q6_V_vsplat_R(1);
const HVX_Vector zero = Q6_V_vzero();
HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
return hvx_vec_get_i32(sum) == 32;
}
// Sorts values and mirrors swaps to indices.
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
if (left >= right) return;
int pivot_idx = (left + right) / 2;
float pivot = values[pivot_idx];
int i = left;
int j = right;
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
while (i <= j) {
// Vectorized scan for i
while (i <= j) {
// Check if we have at least one full vector
if (i + 32 <= j) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
if (all_greater_f32(pivot_vec, vals_vec)) {
// If all elements are < pivot, we can skip this whole block
i += 32;
continue;
}
}
// Scalar fallback / cleanup
if (values[i] < pivot) {
i++;
} else {
break;
}
}
// Vectorized scan for j
while (i <= j) {
if (j - 32 >= i) {
// Load 32 elements ending at j.
// Since we want `values[j] > pivot`, let's load from j-31 to j.
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
if (all_greater_f32(vals_vec, pivot_vec)) {
j -= 32;
continue;
}
}
if (values[j] > pivot) {
j--;
} else {
break;
}
}
if (i <= j) {
float tmp_val = values[i];
values[i] = values[j];
values[j] = tmp_val;
int32_t tmp_idx = indices[i];
indices[i] = indices[j];
indices[j] = tmp_idx;
i++;
j--;
}
}
if (left < j) quicksort_values_indices_asc(values, indices, left, j);
if (i < right) quicksort_values_indices_asc(values, indices, i, right);
}
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
if (left >= right) return;
int pivot_idx = (left + right) / 2;
float pivot = values[pivot_idx];
int i = left;
int j = right;
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
while (i <= j) {
// Vectorized scan for i (values[i] > pivot)
while (i <= j) {
if (i + 32 <= j) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
if (all_greater_f32(vals_vec, pivot_vec)) {
i += 32;
continue;
}
}
if (values[i] > pivot) {
i++;
} else {
break;
}
}
// Vectorized scan for j (values[j] < pivot)
while (i <= j) {
if (j - 32 >= i) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
if (all_greater_f32(pivot_vec, vals_vec)) {
j -= 32;
continue;
}
}
if (values[j] < pivot) {
j--;
} else {
break;
}
}
if (i <= j) {
float tmp_val = values[i];
values[i] = values[j];
values[j] = tmp_val;
int32_t tmp_idx = indices[i];
indices[i] = indices[j];
indices[j] = tmp_idx;
i++;
j--;
}
}
if (left < j) quicksort_values_indices_desc(values, indices, left, j);
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
struct htp_ops_context * octx = actx->octx;
// Unpack context
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * dst = &octx->dst;
// Scratchpad memory
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
// Dimensions
uint32_t ne00 = src0->ne[0];
uint32_t ne01 = src0->ne[1];
uint32_t ne02 = src0->ne[2];
uint32_t ne03 = src0->ne[3];
uint32_t nb01 = src0->nb[1];
//uint32_t nb02 = src0->nb[2];
//uint32_t nb03 = src0->nb[3];
uint32_t nb1 = dst->nb[1];
//uint32_t nb2 = dst->nb[2];
//uint32_t nb3 = dst->nb[3];
// Sort order
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
// Rows to process
uint32_t total_rows = ne01 * ne02 * ne03;
uint32_t rows_per_thread = actx->nrows_per_thread;
uint32_t start_row = rows_per_thread * i;
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
// Scratchpad layout:
// We need space for one row of float data (values) and one row of int32 indices.
// values: ne00 * sizeof(float)
// indices: ne00 * sizeof(int32_t)
// Padded to 128 bytes.
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
float * values_buf = (float *) spad;
int32_t * indices_buf = (int32_t *) (spad + values_size);
for (uint32_t r = start_row; r < end_row; r++) {
uint32_t src_offset = r * nb01;
uint32_t dst_offset = r * nb1;
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
// Initialize indices
for (uint32_t j = 0; j < ne00; j++) {
indices_buf[j] = j;
}
// Sort values and mirror swaps to indices
if (order == GGML_SORT_ORDER_ASC) {
quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
} else {
quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
}
// Copy indices back to DDR
hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
}
}
int op_argsort(struct htp_ops_context * octx) {
// Check supported types
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
// Allocate scratchpad
// We need 1 row of float + 1 row of int32 per thread.
uint32_t ne00 = octx->src0.ne[0];
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
size_t spad_per_thread = values_size + indices_size;
// Make sure we round up to 256 for alignment requirements
spad_per_thread = hex_round_up(spad_per_thread, 256);
size_t total_spad_size = spad_per_thread * octx->n_threads;
if (octx->ctx->vtcm_size < total_spad_size) {
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src0_spad.size = total_spad_size;
octx->src0_spad.size_per_thread = spad_per_thread;
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
octx->src0.data, octx->dst.data);
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
struct htp_argsort_context actx;
actx.octx = octx;
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
// Run jobs
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
return HTP_STATUS_OK;
}
File diff suppressed because it is too large Load Diff
+254 -272
View File
@@ -17,121 +17,6 @@
#include "htp-msg.h"
#include "htp-ops.h"
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// MAD: y (F32) += x (F16) * s (float)
// MAD: y (F32) += x (F16) * s (F32)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
}
}
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
const void * restrict x0,
const void * restrict x1,
float s0,
float s1,
int n) {
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(s0);
HVX_Vector S1 = hvx_vec_splat_f16(s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
}
if (nloe) {
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs = xs_p_lo;
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i;
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
}
}
}
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
struct htp_fa_context {
const struct htp_ops_context * octx;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
float scale;
float max_bias;
float logit_softcap;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t n_blocks;
size_t size_q_row_padded;
size_t size_k_row_padded;
size_t size_v_row_padded;
size_t size_k_block;
size_t size_v_block;
size_t size_m_block;
bool is_q_fp32;
};
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
}
}
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_fa_context * factx = (struct htp_fa_context *) data;
const struct htp_ops_context * octx = factx->octx;
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
// Clear accumulator
hvx_splat_f32_a(spad_a, 0, DV);
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
if (factx->is_q_fp32) {
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
}
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// Inner loop processing the block from VTCM
uint32_t ic = 0;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
// Process in blocks of 32 (VLEN_FP32)
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4;
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; j += 2) {
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else {
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
if (factx->logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
}
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);
}
scores_x4.v[iv] = scores;
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
}
{
// 4. Online Softmax Update
v_max = hvx_vec_reduce_max_f32(v_max);
float m_block = hvx_vec_get_f32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
M_vec = M_new_vec;
const float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = scores_x4.v[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
*(HVX_Vector *) p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
}
}
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}
// Sync scalars for leftover/next block if needed
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
if (factx->logit_softcap != 0.0f) {
s_val = factx->logit_softcap * tanhf(s_val);
}
if (mask) {
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
vs = expf(s_val - M);
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
M_vec = hvx_vec_splat_f32(M);
S_vec = hvx_vec_splat_f32(S);
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
if (ib + 2 < factx->n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
// sinks
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
S = S * ms + vs;
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
struct htp_fa_context factx;
factx.octx = octx;
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
factx.scale = scale;
factx.max_bias = max_bias;
factx.logit_softcap = logit_softcap;
uint32_t n_head = q->ne[2];
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}
return HTP_STATUS_OK;
+26 -38
View File
@@ -42,32 +42,36 @@ enum htp_data_type {
HTP_TYPE_COUNT
};
// These values are manually translated over to HTP
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
// Do not reorder first 4 (used as an index)
enum htp_op {
HTP_OP_MUL = 0,
HTP_OP_ADD = 1,
HTP_OP_SUB = 2,
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT = 4,
HTP_OP_MUL_MAT_ID = 5,
HTP_OP_RMS_NORM = 6,
HTP_OP_UNARY_SILU = 7,
HTP_OP_UNARY_GELU = 8,
HTP_OP_GLU_SWIGLU = 9,
HTP_OP_GLU_SWIGLU_OAI = 10,
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
HTP_OP_FLASH_ATTN_EXT = 14,
HTP_OP_SET_ROWS = 15,
HTP_OP_SCALE = 16,
HTP_OP_GET_ROWS = 17,
HTP_OP_CPY = 18,
HTP_OP_MUL = 0,
HTP_OP_ADD = 1,
HTP_OP_SUB = 2,
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT,
HTP_OP_MUL_MAT_ID,
HTP_OP_RMS_NORM,
HTP_OP_UNARY_SILU,
HTP_OP_UNARY_GELU,
HTP_OP_GLU_SWIGLU,
HTP_OP_GLU_SWIGLU_OAI,
HTP_OP_GLU_GEGLU,
HTP_OP_SOFTMAX,
HTP_OP_ADD_ID,
HTP_OP_ROPE,
HTP_OP_FLASH_ATTN_EXT,
HTP_OP_SET_ROWS,
HTP_OP_GET_ROWS,
HTP_OP_SCALE,
HTP_OP_CPY,
HTP_OP_ARGSORT,
HTP_OP_SQR,
HTP_OP_SQRT,
HTP_OP_SUM_ROWS,
INVALID
};
static inline size_t htp_type_block_size(uint32_t t) {
static inline size_t htp_t_block_size(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return 1;
@@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return 0;
}
static const char * htp_type_name(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return "fp32";
case HTP_TYPE_F16:
return "fp16";
case HTP_TYPE_Q4_0:
return "q4_0";
case HTP_TYPE_Q8_0:
return "q8_0";
case HTP_TYPE_MXFP4:
return "mxfp4";
}
return 0;
}
// Internal types
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
+2 -13
View File
@@ -64,25 +64,12 @@ struct htp_ops_context {
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
uint32_t flags;
};
@@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
@@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */
+129 -116
View File
@@ -46,127 +46,76 @@
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// ADD variants
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
// Dispatcher logic
#define HVX_BINARY_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
if (hex_is_aligned((void *) dst, 128)) { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
else OP_NAME##_aau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
else OP_NAME##_auu(dst, src0, src1, num_elems); \
} \
} else { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
else OP_NAME##_uau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
} \
} \
}
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
}
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
}
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
}
// SUB variants
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
}
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
}
// MUL variants
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
}
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
}
// Dispatchers
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_add_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_add_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_uu(dst, src0, src1, num_elems);
}
}
HVX_BINARY_DISPATCHER(hvx_add_f32)
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
// Mul-Mul Optimized
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
@@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
}
}
//
// Square
//
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
} \
if (nloe) { \
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src, 128)) {
hvx_sqr_f32_aa(dst, src, num_elems);
} else {
hvx_sqr_f32_au(dst, src, num_elems);
}
} else {
if (hex_is_aligned((void *) src, 128)) {
hvx_sqr_f32_ua(dst, src, num_elems);
} else {
hvx_sqr_f32_uu(dst, src, num_elems);
}
}
}
#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER
#endif // HVX_ARITH_H
+6
View File
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
return x;
}
static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
int32_t __attribute__((aligned(128))) x;
hvx_vec_store_a(&x, 4, v);
return x;
}
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
// abs by clearing the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
-2
View File
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector zero = Q6_V_vsplat_R(0); \
\
const uint32_t elem_size = sizeof(__fp16); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
+116
View File
@@ -0,0 +1,116 @@
#ifndef HVX_DIV_H
#define HVX_DIV_H
#include <HAP_farf.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
#include "hex-utils.h"
#include "hvx-inverse.h"
#include "hvx-arith.h"
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
} \
} while(0)
// 3-letter suffix variants
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
else hvx_div_f32_aau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
else hvx_div_f32_auu(dst, src0, src1, num_elems);
}
} else {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
else hvx_div_f32_uau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
}
}
}
#undef HVX_OP_MUL
#endif // HVX_DIV_H
+27
View File
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
} \
} while(0)
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t epv = 128 / sizeof(float); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
} \
if (nloe) { \
HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
} \
} while(0)
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
#endif /* HVX_SIGMOID_H */
+67 -1
View File
@@ -12,11 +12,17 @@
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
//Algorithm :
// x2 = input*0.5
// y = * (long *) &input
// y = 0x5f3759df - (y>>2)
// y = 0x5f3759df - (y>>1)
// y = y*(threehalfs - x2*y*y)
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
@@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
return Q6_Vsf_equals_Vqf32(temp);
}
// Compute sqrt(x) as x*inv_sqrt(x)
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vdst[i] = sqrt_res; \
} \
if (nloe) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
} \
} while(0)
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
if ((unsigned long) dst % 128 == 0) {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_aa(dst, src, num_elems);
} else {
hvx_sqrt_f32_au(dst, src, num_elems);
}
} else {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_ua(dst, src, num_elems);
} else {
hvx_sqrt_f32_uu(dst, src, num_elems);
}
}
}
#endif /* HVX_SQRT_H */
+1
View File
@@ -12,6 +12,7 @@
#include "hvx-sigmoid.h"
#include "hvx-sqrt.h"
#include "hvx-arith.h"
#include "hvx-div.h"
#include "hvx-base.h"
#endif /* HVX_UTILS_H */
+108 -1
View File
@@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
// otherwise we'll release it once we're done with the current Op.
if (ctx->vtcm_inuse) {
ctx->vtcm_needs_release = false;
ctx->vtcm_needs_release = true;
return 0;
}
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_argsort(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
@@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_sum_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_activations_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
@@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
case HTP_OP_DIV:
if (n_bufs != 3) {
FARF(ERROR, "Bad binary-req buffer list");
continue;
@@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_unary_req(ctx, &req, bufs);
break;
case HTP_OP_SQR:
case HTP_OP_SQRT:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
}
proc_unary_req(ctx, &req, bufs);
break;
case HTP_OP_SUM_ROWS:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
}
proc_sum_rows_req(ctx, &req, bufs);
break;
case HTP_OP_UNARY_SILU:
case HTP_OP_UNARY_GELU:
if (n_bufs != 2) {
@@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_GLU_SWIGLU:
case HTP_OP_GLU_SWIGLU_OAI:
case HTP_OP_SOFTMAX:
case HTP_OP_GLU_GEGLU:
if ((n_bufs != 2) && (n_bufs != 3)) {
FARF(ERROR, "Bad act-req buffer list");
continue;
@@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_cpy_req(ctx, &req, bufs);
break;
case HTP_OP_ARGSORT:
if (n_bufs != 2) {
FARF(ERROR, "Bad argsort-req buffer list");
continue;
}
proc_argsort_req(ctx, &req, bufs);
break;
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;
File diff suppressed because it is too large Load Diff
+115
View File
@@ -0,0 +1,115 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <string.h>
#include <math.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#define sum_rows_preamble \
struct htp_tensor *src0 = &octx->src0;\
struct htp_tensor *dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne03 = src0->ne[3]; \
\
const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t ne3 = dst->ne[3]; \
\
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
sum_rows_preamble;
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return HTP_STATUS_OK;
}
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
const float * restrict src_local = src_th + (ir * ne00);
if (ir + 1 < src0_nrows_per_thread) {
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
}
if (1 == opt_path) {
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
} else {
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
}
}
return HTP_STATUS_OK;
}
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
}
int op_sum_rows(struct htp_ops_context * octx) {
sum_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
const int n_threads = octx->n_threads;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
return HTP_STATUS_OK;
}
+64
View File
@@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
}
}
static void sqr_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
}
static void sqrt_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
}
static void unary_job_f32_per_thread(const struct htp_tensor * src,
struct htp_tensor * dst,
uint8_t * spad,
@@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQR:
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQRT:
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
default:
break;
@@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
case HTP_OP_SQR:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqr-f32";
break;
case HTP_OP_SQRT:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqrt-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
+4
View File
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
}
}
static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
static inline float ggml_compute_softplus_f32(float input) {
return (input > 20.0f) ? input : logf(1 + expf(input));
}
+13 -2
View File
@@ -264,15 +264,26 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_SUM_ROWS:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_CLAMP:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:
case GGML_OP_UNARY:
case GGML_OP_GET_ROWS:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_REPEAT:
return true;
default:
return ggml_op_is_empty(op);
@@ -312,7 +323,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
h_add(mrs1, node0);
// that many nodes forward to search for a concurrent node
constexpr int N_FORWARD = 8;
constexpr int N_FORWARD = 64;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {
+76 -50
View File
@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
char base[256];
char name[256];
const int64_t n = ggml_nelements(op);
int op_num = -1;
const char * op_str = "undefined";
switch (op->op) {
case GGML_OP_SCALE: op_str = "scale"; break;
case GGML_OP_FILL: op_str = "fill"; break;
case GGML_OP_CLAMP: op_str = "clamp"; break;
case GGML_OP_SQR: op_str = "sqr"; break;
case GGML_OP_SQRT: op_str = "sqrt"; break;
case GGML_OP_SIN: op_str = "sin"; break;
case GGML_OP_COS: op_str = "cos"; break;
case GGML_OP_LOG: op_str = "log"; break;
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
case GGML_UNARY_OP_STEP: op_str = "step"; break;
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
};
const char * suffix = "";
if (n % 4 == 0) {
suffix = "_4";
}
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.c4 = is_c4;
res.cnt = is_cnt;
return res;
}
@@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
const char * op_str = "undefined";
int op_num = -1;
switch (op->op) {
case GGML_OP_SUM_ROWS:
op_str = "sum_rows"; break;
case GGML_OP_MEAN:
op_str = "mean"; break;
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d", base, op_num);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32*sizeof(float);
if (is_c4) {
res.smem *= 4;
}
res.c4 = is_c4;
return res;
}
@@ -1472,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1486,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
res.c4 = is_c4;
res.smem = 32*sizeof(float);
return res;
+13 -14
View File
@@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
}
switch (op->op) {
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CLAMP:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
@@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
}
@@ -1058,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
@@ -1087,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
@@ -1161,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
+73 -20
View File
@@ -80,7 +80,9 @@
#define FC_SSM_CONV 900
#define FC_SOLVE_TRI 1000
#define FC_COUNT_EQUAL 1100
#define FC_BIN 1200
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -89,6 +91,37 @@
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
#define OP_UNARY_NUM_SCALE 10
#define OP_UNARY_NUM_FILL 11
#define OP_UNARY_NUM_CLAMP 12
#define OP_UNARY_NUM_SQR 13
#define OP_UNARY_NUM_SQRT 14
#define OP_UNARY_NUM_SIN 15
#define OP_UNARY_NUM_COS 16
#define OP_UNARY_NUM_LOG 17
#define OP_UNARY_NUM_LEAKY_RELU 18
#define OP_UNARY_NUM_TANH 100
#define OP_UNARY_NUM_RELU 101
#define OP_UNARY_NUM_SIGMOID 102
#define OP_UNARY_NUM_GELU 103
#define OP_UNARY_NUM_GELU_ERF 104
#define OP_UNARY_NUM_GELU_QUICK 105
#define OP_UNARY_NUM_SILU 106
#define OP_UNARY_NUM_ELU 107
#define OP_UNARY_NUM_NEG 108
#define OP_UNARY_NUM_ABS 109
#define OP_UNARY_NUM_SGN 110
#define OP_UNARY_NUM_STEP 111
#define OP_UNARY_NUM_HARDSWISH 112
#define OP_UNARY_NUM_HARDSIGMOID 113
#define OP_UNARY_NUM_EXP 114
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -124,6 +157,31 @@ typedef struct {
int32_t dim;
} ggml_metal_kargs_concat;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float slope;
float scale;
float bias;
float val;
float min;
float max;
} ggml_metal_kargs_unary;
typedef struct {
int32_t ne00;
int32_t ne01;
@@ -181,20 +239,6 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_repeat;
typedef struct {
float scale;
float bias;
} ggml_metal_kargs_scale;
typedef struct {
float val;
} ggml_metal_kargs_fill;
typedef struct {
float min;
float max;
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
@@ -498,8 +542,21 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;
@@ -881,10 +938,6 @@ typedef struct {
int max_period;
} ggml_metal_kargs_timestep_embedding;
typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int32_t ne00;
int32_t ne01;
+264 -197
View File
@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_acc(ctx, idx);
} break;
case GGML_OP_SCALE:
{
n_fuse = ggml_metal_op_scale(ctx, idx);
} break;
case GGML_OP_FILL:
{
n_fuse = ggml_metal_op_fill(ctx, idx);
} break;
case GGML_OP_CLAMP:
{
n_fuse = ggml_metal_op_clamp(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
} break;
case GGML_OP_TRI:
{
n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
} break;
case GGML_OP_SET:
{
n_fuse = ggml_metal_op_set(ctx, idx);
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
@@ -628,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -679,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
}
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb00,
/*.nb01 =*/ pnb1,
/*.nb02 =*/ pnb2,
@@ -695,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
@@ -715,126 +707,19 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = 1;
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
return 1;
}
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float bias;
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_scale args = {
/*.scale =*/ scale,
/*.bias =*/ bias,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float val = ggml_get_op_params_f32(op, 0);
ggml_metal_kargs_fill args = {
/*.val =*/ val
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float min;
float max;
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_clamp args = {
/*.min =*/ min,
/*.max =*/ max,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -846,19 +731,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
int64_t n = ggml_nelements(op);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
if (n % 4 == 0) {
n /= 4;
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_unary args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.slope =*/ 0.0,
/*.scale =*/ 0.0,
/*.bias =*/ 0.0,
/*.val =*/ 0.0,
/*.min =*/ 0.0,
/*.max =*/ 0.0,
};
if (op->op == GGML_OP_LEAKY_RELU) {
args.slope = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_SCALE) {
args.scale = ggml_get_op_params_f32(op, 0);
args.bias = ggml_get_op_params_f32(op, 1);
}
if (op->op == GGML_OP_FILL) {
args.val = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_CLAMP) {
args.min = ggml_get_op_params_f32(op, 0);
args.max = ggml_get_op_params_f32(op, 1);
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne00, nth_max);
const int nk0 = (args.ne00 + nth - 1)/nth;
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}
@@ -969,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -990,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -1664,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
const size_t offs = ((const int32_t *) op->op_params)[3];
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
}
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
int64_t nk0 = ne10;
if (ggml_is_quantized(op->src[1]->type)) {
nk0 = ne10/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne10/ggml_blck_size(op->type);
}
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
}
}
}
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb10,
/*.nb01 =*/ nb11,
/*.nb02 =*/ nb12,
/*.nb03 =*/ nb13,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ ggml_element_size(op),
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
bid_dst.offs += offs;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
return 1;
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -3044,39 +3127,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
@@ -4084,42 +4187,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float slope;
memcpy(&slope, op->op_params, sizeof(float));
ggml_metal_kargs_leaky_relu args = {
/*.slope =*/ slope
};
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
+1 -4
View File
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
@@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
@@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
+275 -466
View File
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -895,6 +903,210 @@ enum ggml_sort_order {
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
}
#undef FC_OP
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
@@ -1114,414 +1326,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
kernel void kernel_scale_f32(
constant ggml_metal_kargs_scale & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_scale_f32_4(
constant ggml_metal_kargs_scale & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_fill_f32(
constant ggml_metal_kargs_fill & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_fill_f32_4(
constant ggml_metal_kargs_fill & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_clamp_f32(
constant ggml_metal_kargs_clamp & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = clamp(src0[tpig], args.min, args.max);
}
kernel void kernel_clamp_f32_4(
constant ggml_metal_kargs_clamp & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = clamp(src0[tpig], args.min, args.max);
}
kernel void kernel_relu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
kernel void kernel_relu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
kernel void kernel_sigmoid_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_sigmoid_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_tanh_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = precise::tanh(src0[tpig]);
}
kernel void kernel_tanh_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = precise::tanh(src0[tpig]);
}
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
kernel void kernel_gelu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
// This was observed with Falcon 7B and 40B models
//
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_quick_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
kernel void kernel_gelu_erf_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}
kernel void kernel_silu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_silu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_elu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
}
kernel void kernel_elu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
}
kernel void kernel_sqr_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
kernel void kernel_sqr_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
kernel void kernel_sqrt_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sqrt(src0[tpig]);
}
kernel void kernel_sqrt_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sqrt(src0[tpig]);
}
kernel void kernel_sin_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_sin_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_cos_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_cos_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_log_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = log(src0[tpig]);
}
kernel void kernel_log_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = log(src0[tpig]);
}
kernel void kernel_neg_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_neg_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_abs_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_abs_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_sgn_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sign(src0[tpig]);
}
kernel void kernel_sgn_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sign(src0[tpig]);
}
kernel void kernel_step_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = step(0.0f, src0[tpig]);
}
kernel void kernel_step_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = step(0.0f, src0[tpig]);
}
kernel void kernel_hardswish_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardswish_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_exp_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_exp_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_softplus_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_softplus_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_expm1_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_reglu_f32(
constant ggml_metal_kargs_glu & args,
device const char * src0,
@@ -1705,33 +1509,35 @@ kernel void kernel_op_sum_f32(
}
}
template <bool norm>
kernel void kernel_sum_rows(
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
#define FC_OP FC_sum_rows_op
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
shmem_t[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
@@ -1742,23 +1548,33 @@ kernel void kernel_sum_rows(
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
@@ -2639,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t ne02 = args.ne02;
const int32_t ne03 = args.ne03;
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
@@ -2928,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@@ -2965,12 +2784,16 @@ kernel void kernel_l2_norm_f32(
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -5072,24 +4895,6 @@ kernel void kernel_argsort_merge_f32_i32(
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
kernel void kernel_leaky_relu_f32(
constant ggml_metal_kargs_leaky_relu & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = x > 0.0f ? x : x * args.slope;
}
kernel void kernel_leaky_relu_f32_4(
constant ggml_metal_kargs_leaky_relu & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -6161,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
const short T = PK + NSG*SH; // shared memory size per query in (half)
//const short T = PK + NSG*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
@@ -8749,7 +8554,9 @@ kernel void kernel_mul_mm(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -8872,8 +8679,8 @@ kernel void kernel_mul_mm(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
@@ -9122,7 +8929,9 @@ kernel void kernel_mul_mm_id(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -9257,8 +9066,8 @@ kernel void kernel_mul_mm_id(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
@@ -9939,7 +9748,7 @@ kernel void kernel_opt_step_sgd_f32(
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_fill & args,
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
+6
View File
@@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q4_1_f32
mul_mv_q4_1_f32_flat
mul_mv_q4_k_f32
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
@@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS
gemv_moe_mxfp4_f32
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_general_q8_0_f32
mul
File diff suppressed because it is too large Load Diff
+51
View File
@@ -46,6 +46,15 @@ struct block_q4_0
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// block_q4_1
//------------------------------------------------------------------------------
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
@@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_1
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_1(
global struct block_q4_1 * src0,
global uchar * dst_q,
global half * dst_d,
global half * dst_m
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
global half * m = (global half *) dst_m + get_global_id(0);
*d = b->d;
*m = b->m;
for (int i = 0; i < QK4_1/2; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q4_1(
global uchar * src_q,
global half * src_d,
global half * src_m,
global struct block_q4_1 * dst
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
global half * m = (global half *) src_m + get_global_id(0);
b->d = *d;
b->m = *m;
for (int i = 0; i < QK4_1/2; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------
+87 -56
View File
@@ -3,80 +3,111 @@
//------------------------------------------------------------------------------
// expm1
//------------------------------------------------------------------------------
kernel void kernel_expm1_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
kernel void kernel_expm1_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_mean_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum / ne00;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
kernel void kernel_mean_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
@@ -0,0 +1,163 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,165 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global half * src0_m,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
float m = (float)src0_m[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,158 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 2
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 128; // 2 values per idx
int iqs = idx % 128; // 0..127
int n = iqs / 64; // 0,1
int b = (iqs % 64) / 32; // 0,1
int is_b = (iqs % 16) / 8; // 0,1
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
int is = 8 * n + qhshift + is_b; // 0..15
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,219 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y(
global const struct block_q4_1 * qb_curr,
float sumy,
float16 yl,
int il
) {
float d = qb_curr->d;
float m = qb_curr->m;
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32(
global void * src0,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
@@ -0,0 +1,229 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y_flat(
global const uchar * x,
global const half * dh,
global const half * mh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
float m = *mh;
global const ushort * qs = ((global const ushort *) x + il/2);
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// The number of scales/mins is the same as the number of blocks.
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_dm;
global half * m = (global half *) src0_m + offset0_dm;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
@@ -0,0 +1,180 @@
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define K_SCALE_SIZE 12
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uchar qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32(
global char * src0,
int offset0,
global char * src1,
int offset1,
global char * dst,
int offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
ushort kmask1 = 0x3f3f;
ushort kmask2 = 0x0f0f;
ushort kmask3 = 0xc0c0;
int ix = get_sub_group_local_id()/8; // super block index
int it = get_sub_group_local_id()%8; // block index (inside super block)
int iq = it/4; // 0 or 1 - first or second half of the super block
int ir = it%4; // 0...3 - block index in the half super block
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
global float * y = (global float *) (src1 + offset_src1);
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f};
float all_sum;
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
ushort sc16[4];
uchar * sc8 = (uchar *)sc16;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+0];
sumy.s0 += yl[i+0];
yl[i+8] = y4[i+32];
sumy.s1 += yl[i+8];
yh[i+0] = y4[i+128];
sumy.s2 += yh[i+0];
yh[i+8] = y4[i+160];
sumy.s3 += yh[i+8];
}
global ushort * sc = (global ushort *)x[ib].scales + iq;
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
global half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
global ushort * q2 = q1 + 32;
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
q1 += nb01/2;
sc += nb01/2;
dh += nb01/2;
}
y4 += BLOCK_STRIDE * QK_K;
}
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = sub_group_reduce_add(sumf[row]);
if (first_row + row < ne01) {
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}
}
+88 -60
View File
@@ -3,86 +3,114 @@
//------------------------------------------------------------------------------
// softplus
//------------------------------------------------------------------------------
inline float softplus_f32(float x){
float ax = fabs(x);
float m = fmax(x, 0.0f);
return log1p(exp(-ax)) + m;
kernel void kernel_softplus_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
const float x = convert_float(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
const float4 x = convert_float4(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = softplus_f32(*src_val_ptr);
}
*y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
}
}
kernel void kernel_softplus_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
}
const float x = convert_float(*hx);
*hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_sum_rows_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
kernel void kernel_sum_rows_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
+92 -41
View File
@@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
#define VK_VENDOR_ID_QUALCOMM 0x5143
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
@@ -687,6 +688,7 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
vk_pipeline pipeline_add[2][2][2];
@@ -942,6 +944,7 @@ struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
@@ -961,6 +964,7 @@ struct vk_mat_vec_push_constants {
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
@@ -4080,7 +4084,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4181,7 +4185,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -5641,6 +5646,10 @@ static void ggml_vk_instance_init() {
driver_priorities[vk::DriverId::eMesaNvk] = 2;
#endif
break;
case VK_VENDOR_ID_QUALCOMM:
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
break;
}
driver_priorities[vk::DriverId::eMesaDozen] = 100;
@@ -6766,8 +6775,16 @@ static void ggml_vk_matmul(
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
base_work_group_z += groups_z;
}
return;
}
@@ -6781,9 +6798,17 @@ static void ggml_vk_matmul(
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
base_work_group_z += groups_z;
}
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7179,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (qx_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
@@ -7477,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7572,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags, base_work_group_y,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, groups_y, groups_z });
base_work_group_y += groups_y;
}
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
@@ -7825,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
src1->nb[2] <= src1->nb[1] &&
src1->nb[1] <= src1->nb[3] &&
src0->ne[3] == 1 &&
src1->ne[3] == 1) {
src1->ne[3] == 1 &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
@@ -8422,6 +8457,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
@@ -8438,7 +8475,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t slope = Br * acctype;
const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
@@ -8815,6 +8852,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_acc_f32;
}
return nullptr;
case GGML_OP_SET:
if (src0->type == src1->type && src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
return ctx->device->pipeline_set_f32;
}
return nullptr;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9801,16 +9844,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
0,
0.0f, 0.0f, offset,
});
@@ -10624,8 +10667,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
const float * op_params = (const float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11543,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12052,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
// y[i] = i % k;
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12500,6 +12543,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ACC:
case GGML_OP_SET:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
break;
@@ -14896,8 +14940,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -14960,7 +15006,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_SET:
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
@@ -15611,6 +15660,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_SET) {
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {
+16 -8
View File
@@ -3,6 +3,9 @@
#include "types.glsl"
#include "generic_binary_head.glsl"
// false for SET, true for ACC
layout(constant_id = 1) const bool ACC = true;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
@@ -13,17 +16,22 @@ void main() {
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
const uint i3 = src1_i / p.nb03;
const uint rem2 = src1_i - i3 * p.nb03;
const uint i2 = rem2 / p.nb02;
const uint rem1 = rem2 - i2 * p.nb02;
const uint i1 = rem1 / p.nb01;
const uint i0 = rem1 % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
if (ACC) {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
}
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
}
}
@@ -130,6 +130,7 @@ void main() {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
@@ -137,12 +138,25 @@ void main() {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
@@ -260,6 +274,9 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
@@ -213,6 +215,19 @@ void main() {
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
@@ -176,7 +176,14 @@ void main() {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
@@ -184,7 +191,14 @@ void main() {
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
}
@@ -1,6 +1,6 @@
#version 450
#include "generic_head.glsl"
#include "generic_unary_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,22 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
sum[tid] += xi * xi;
}
@@ -35,7 +38,7 @@ void main() {
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
}
}

Some files were not shown because too many files have changed in this diff Show More