Compare commits

...

133 Commits

Author SHA1 Message Date
Sigbjørn Skjæret e8e261699a cli : provide model with text filename (#19783) 2026-02-22 22:33:49 +01:00
Xuan-Son Nguyen 5452d736f8 jinja: correct stats for tojson and string filters (#19785) 2026-02-22 21:08:23 +01:00
Aldehir Rojas ed4837891d common : fix improper trimming in XML parser on complete message (#19805)
Co-authored-by: Jules LEIDELINGER <11395311+julio75012@users.noreply.github.com>
2026-02-22 17:34:54 +01:00
Kilian Krampf cacc371f99 Fix wrong cli-argument in documentation (#19804) 2026-02-22 16:26:33 +01:00
HelloKS ae2368e74e model : add Kanana-2 model support (#19803)
* model: Add Kanana-2 model support

* lint: adjust spacing
2026-02-22 16:15:02 +01:00
Sigbjørn Skjæret 9f0684f003 ci : fix rocm archive name [no ci] (#19808) 2026-02-22 16:14:37 +01:00
Aldehir Rojas 34ec1c3f18 server : merge contiguous Responses input items into a single assistant message (#19773)
* server : merge contiguous input items into a single assistant message

* cont : simplify tool call msg

* cont : reduce and combine content

* cont : fix merging content items
2026-02-22 14:11:31 +01:00
Sigbjørn Skjæret e877ad8bd9 ci : fix rocm release path [no ci] (#19784) 2026-02-22 08:07:46 +01:00
Mario Limonciello 35715657cb Update ROCm docker container to 7.2 release (#19418)
Also update architectures
2026-02-21 21:53:39 +01:00
Mario Limonciello f75c4e8bf5 Add a build target to generate ROCm artifacts using ROCm 7.2 (#19433)
This builds the following targets:
 * gfx1151
 * gfx1150
 * gfx1200
 * gfx1201
 * gfx1100
 * gfx1101
 * gfx1030
 * gfx908
 * gfx90a
 * gfx942
2026-02-21 19:56:26 +01:00
Adrien Gallouët 99156f3a5f vendor : update cpp-httplib to 0.33.1 (#19778)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-21 19:12:31 +01:00
Gaurav Garg a0c91e8f9f Improve CUDA graph capture (#19754)
* Improve CUDA graph capture

Currently, CUDA graphs are eagerly enabled on the first call to ggml_backend_cuda_graph_compute. If the graph properties keep changing (4+ consecutive updates), the graph is permanently disabled. This is suboptimal because:

- The first call always incurs CUDA graph capture overhead even if the graph is unstable
- Once permanently disabled, CUDA graphs never re-enable even after the graph stabilizes (e.g., switching from prompt processing to decode)

The new approach delays CUDA graph activation until warmup completes: the same cgraph must be called at least twice with matching properties before CUDA graph capture begins. This avoids wasted capture overhead on volatile graphs and allows graphs to become eligible once they stabilize.
This also fixes issues such as https://github.com/ggml-org/llama.cpp/discussions/19708

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Remove EM dashes

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-02-21 15:09:36 +05:30
crsawyer 07968d53e4 fix: UI single model selection in router mode (#19767) 2026-02-21 09:28:39 +01:00
Mengsheng Wu ba3b9c8844 hexagon : fix build release (#19444) (#19587) 2026-02-20 16:40:00 -08:00
Aldehir Rojas 94b0200a01 common : merge qwen3-coder and nemotron nano 3 parsers (#19765)
* common : migrate qwen3-coder to PEG parsing variant

* cont : add JSON parameter test
2026-02-20 23:22:22 +01:00
Taimur Ahmad b908baf182 ggml-cpu: add RVV vec dot kernels for quantization types (#18784)
* ggml-cpu: add rvv vec_dot for iq2_s

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: add rvv vec_dot for iq3_s

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: add rvv vec_dot for tq1_0, tq2_0

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

ggml-cpu: add rvv vec_dot for tq1_0, tq2_0

* ggml-cpu: add rvv vec_dot for iq1_s, iq1_m

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: add vlen switch for rvv vec_dot

---------

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
2026-02-20 13:30:07 +02:00
ddh0 492bc31978 quantize : add --dry-run option (#19526)
* clean slate for branch

* use 6 characters for tensor dims

* add --dry-run to llama-quantize

* use 6 characters for tensor dims (cont.)

* no need to re-calculate ggml_nbytes for tensor

* fix indent

* show model and quant BPW when quant completes

* add example to --help

* new function `tensor_requires_imatrix`, add courtesy warning about imatrix

* missing __func__, move imatrix flag set

* logic error

* fixup tensor_requires_imatrix

* add missing `GGML_TYPE`s

* simplify and rename `tensor_type_requires_imatrix`

* simplify for style

* add back Q2_K edge case for imatrix

* guard ftype imatrix warning

* comment ref #12557

* remove per @compilade

* remove unused `params` parameter

* move `bool dry_run` per GG

* move `bool dry_run` per GG

* Update src/llama-quant.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-quant.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-quant.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-20 09:20:16 +01:00
Jeff Bolz 77d6ae4ac8 test: mul_mat tests with huge batch size (#19519) 2026-02-19 20:08:25 -06:00
crsawyer 10b26ee23a WebUI hide models in router mode (#19374) 2026-02-19 22:53:42 +01:00
Jesse Posner 3dadc88b58 common : fix Step-3.5-Flash format detection and thinking support (#19635)
* common : fix Step-3.5-Flash format detection and thinking support

Step-3.5-Flash uses the same XML-style tool call format as Qwen3-Coder
(<tool_call><function=...><parameter=...>) but its Jinja template lacks
the bare <function> and plural <parameters> markers that the detection
logic previously required. This caused it to fall through to Hermes 2
Pro, which doesn't call func_args_not_string(), so arguments stayed as
JSON strings and templates using arguments|items crashed.

Additionally, the Qwen3-Coder-XML format handler had no thinking support.
Models like Step-3.5-Flash that unconditionally emit <think> in their
generation prompt need the same thinking_forced_open handling that
Nemotron v3 and Hermes 2 Pro already have, otherwise reasoning_content
is never separated from content in API responses.

Changes:
- Relax Qwen3-Coder XML detection to only require the 3 shared markers
- Tighten Nemotron v3 branch to also require bare <function> and plural
  <parameters>, preventing Step-3.5-Flash from being misrouted via <think>
- Add thinking_forced_open support to Qwen3-Coder-XML init function
- Add <think>/</think> to preserved tokens
- Fix build_grammar_xml_tool_call to handle thinking_forced_open in the
  grammar root rule, allowing </think> before tool calls
- Add Step-3.5-Flash chat template and format detection test

Builds on: https://github.com/ggml-org/llama.cpp/pull/19283

* chat : route Step-3.5-Flash to Nemotron v3 PEG parser, add tests

Step-3.5-Flash uses the same XML tool call format as Qwen3-Coder and
Nemotron 3 Nano (<tool_call>/<function=...>/<parameter=...>) but with
unconditional <think> output. Route it to the Nemotron v3 PEG parser
for streaming and schema-aware parameter parsing.

Detection: templates with <think> + XML tool tags use Nemotron v3 PEG
parser; templates without <think> (Qwen3-Coder) use GBNF grammar.

Tests cover: basic messages, tool calls with/without thinking content,
parallel tool calls, code string parameters, optional </parameter>
closing tags, and JSON schema response format.

* chat : remove dead thinking code from qwen3_coder_xml

Remove thinking handling code that became unreachable after routing
Step-3.5-Flash to the Nemotron v3 PEG parser. Qwen3-Coder has no
<think> in its template, so the thinking_forced_open logic, preserved
tokens, and grammar prefix were dead paths.
2026-02-19 22:40:52 +01:00
abhijitb11 39e4b1dc9b common : fix gpt-oss Jinja error when assistant message has both content and thinking with tool calls (#19704) 2026-02-19 14:59:20 -06:00
Masashi Yoshimura 11c325c6e0 ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (#19700)
* ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support.

* Fix to cast the src value to f32 before sin/cos computing.
2026-02-19 09:18:30 -07:00
megemini 237958db33 model: Add PaddleOCR-VL model support (#18825)
* support PaddleOCR-VL

* clip: update PaddleOCR model loader parameters to prevent OOM during warmup

* [update] add paddleocr vl text model instead of ernie4.5

* [update] restore change of minicpmv

* [update] format

* [update] format

* [update] positions and patch merge permute

* [update] mtmd_decode_use_mrope for paddleocr

* [update] image min/max pixels

* [update] remove set_limit_image_tokens

* upate: preprocess without padding

* clean up

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 17:05:25 +01:00
Ruben Ortlam abb9f3c42b vulkan: fix MMQ shader push constants and multi-dispatch (#19732) 2026-02-19 14:59:16 +01:00
Georgi Gerganov da348c9dfb models : fix qwen3.5 beta/gate shapes (#19730)
* models : fix qwen3.5 beta/gate shapes

* cont : avoid extra reshapes
2026-02-19 15:19:53 +02:00
Saba Fallah e6267a9359 mtmd: build_attn modified, flash_attn on/off via ctx_params (#19729) 2026-02-19 13:50:29 +01:00
3 a l i 2bf318fd2f model : add JAIS-2 architecture support (#19488)
* model: add JAIS-2 architecture support

Add support for the JAIS-2 family of Arabic-English bilingual models
from Inception AI (https://huggingface.co/inceptionai/Jais-2-8B-Chat).

Architecture characteristics:
- LayerNorm (not RMSNorm) with biases
- ReLU² (ReLU squared) activation function
- Separate Q/K/V projections with biases
- Simple MLP without gate projection (up -> act -> down)
- RoPE positional embeddings
- GPT-2 BPE tokenizer

Supported model sizes:
- Jais-2-8B (32 layers, 26 heads, 3328 hidden)
- Jais-2-70B (68 layers, 56 heads, 7168 hidden)

Tested with quantizations: BF16, Q8_0, Q6_K, Q5_K_M, Q5_0, Q4_K_M, Q4_0, Q3_K_M, Q2_K

Note: JAIS-2 requires F32 precision accumulators for numerical stability
and uses standard attention (not flash attention) on CUDA backends.

* fix: run convert_hf_to_gguf_update.py for jais-2 tokenizer hash

* fix: use NEOX RoPE type for JAIS2

* fix: remove Q/K permutation (NEOX RoPE doesn't need it)

* fix: enable flash attention for JAIS2 (fixed by #19115)

* fix: add dedicated JAIS2 pre-tokenizer type and control vector support

- Add LLAMA_VOCAB_PRE_TYPE_JAIS2 with cascading whitespace regex
- Include original regex from tokenizer.json as comment
- Add build_cvec call for control vector support

* no longer necessary to override set_vocab

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 13:30:17 +01:00
Johannes Gäßler c78e682245 CUDA: fix kernel selection logic for tile FA (#19686)
* CUDA: fix kernel selection logic for tile FA

* add comment
2026-02-19 12:42:58 +01:00
Tarek Dakhran c5897995a7 mtmd : chat : Fix extra \n between text and media marker (#19595)
* mtmd : chat : Fix extra \n between text and media marker

Thanks to @tugot17 for detecting and reporting the issue.

For vision models (e.g. LFM2.5-VL-1.6B and Qwen/Qwen3-VL-4B-Instruct) `llama-mtmd-cli` produces identical output to HF implementation.

However `llama-server` doesn't. I traced it down to extra newline
inserted after `<__media__>`.

This happens in `to_json_oaicompat`, that treats media markers as text
and joins all parts with `\n` separator.

PR introduces new type `media_marker` and uses it for media markers.
Extra logic is added to prevent insertion of newlines before and after
media markers.

With this change number of input tokens is identical to HF
implementation and as a result the output is also identical.

I explored other ways to address the issue
* remove completely `\n` between text parts in `to_json_oaicompat`
* merge text messages in server-common.cpp before sending them to `to_json_oaicompat`

Please propose alternative ways of fixing this issue.

* Refactor to use explicite per type ifs

* Update common/chat.cpp

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>

* Update common_chat_templates_apply_legacy

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-02-19 12:18:57 +01:00
Aleksander Grygier 03fd9d3bb4 webui: Fix Attachments not being included in completion request (#19731)
* fix: Add missing argument

* chore: update webui build output
2026-02-19 10:27:38 +01:00
Tarek Dakhran 8004f3a8d1 model : add tokenizer from LFM2.5-Audio-1.5B (#19687)
* model : Add tokenizer from LFM2.5-Audio-1.5B

[LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B) introduced lightweight audio tokenizer.

Tokenizer based on LFM2 architecture and acts as "embedding" model with
different input `n_embd` and output `n_embd_out`.

To be used in https://github.com/ggml-org/llama.cpp/pull/18641.

To convert use

```shell
python3 convert_hf_to_gguf.py /path/to/LFM2.5-Audio-1.5B/audio_detokenizer
```

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Formatting

* Rework check for attention layers

* Add LFM2 SWA model support

* Address PR feedback

* Set vocab to none

* Move helper function definitions to cpp file

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 09:54:48 +01:00
Daniel Bevenius eacb4b67a2 llama : use output_resolve_row() in get_logits_ith/get_embeddings_ith (#19663)
This commit updates get_logits_ith(), and get_embeddings_ith() to use
output_resolve_row() to resolve the batch index to output row index.

The motivation for this is to remove some code duplication between these
functions.
2026-02-19 09:48:08 +01:00
Ryan Mangeno c0d0430340 model : full modern bert support (#18330)
* full modern bert support

* added gelu op in rank pooling for modern bert

* still working on stuff, added mean calculation before classifier head

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* first layer is dense, as per modern bert research paper

* Update src/llama-graph.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fixed set input for mean pooling to check if pooling type is ranking since modern bert does mean & rank

* Update src/llama-graph.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 08:52:21 +01:00
shalinib-ibm 3bb2fcc856 llamafile: powerpc: add FP16 MMA path for Q4/Q8 matmul (#19709)
Avoid xvi8ger4pp signed→unsigned bias correction by dequantizing Q4/Q8
inputs to FP16 and using FP16×FP16→FP32 MMA. This removes
post-processing overhead and improves performance.

Performance Impact:
1.5 ~ 2x improvement in PP_Speed for Q4 and Q8 Models,
measured with llama-bench and llama-batched-bench.
Q8 Model: granite-4.0-h-micro-Q8_0.gguf (from huggingface)
Q4 Model: Meta-Llama3-8b Q4 model (generated with llama-quantize from
f32 model)

llama-bench Q8 Model Results:
 model                          	       size 	     params 	 backend    	 threads 	            test 	Base t/s	Patch t/s
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	             pp8 	         64.48 ± 4.72 	         73.99 ± 0.27
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	            pp16 	         80.11 ± 0.32 	        112.53 ± 0.40
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	            pp32 	         89.10 ± 0.27 	        152.95 ± 0.68
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	            pp64 	         93.65 ± 0.25 	        187.83 ± 0.83
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           pp128 	         99.93 ± 0.02 	        201.32 ± 0.11
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           pp256 	        102.32 ± 0.40 	        208.32 ± 0.41
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           pp512 	        103.42 ± 0.40 	        209.98 ± 0.14
 granitehybrid 3B Q8_0          	   3.16 GiB 	     3.19 B 	 CPU        	      10 	           tg128 	         20.35 ± 0.01 	         19.57 ± 0.01

llama-bench Q4 Model Results:
 model                          	       size 	     params 	 backend    	 threads 	            test 	              Base    t/s 	               Patch   t/s
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	             pp8 	         34.77 ± 0.10 	         41.23 ± 0.08
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	            pp16 	         40.81 ± 0.04 	         64.55 ± 0.15
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	            pp32 	         44.65 ± 0.05 	         90.84 ± 0.22
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	            pp64 	         47.49 ± 0.03 	        114.39 ± 0.11
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           pp128 	         49.29 ± 0.24 	        120.13 ± 0.19
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           pp256 	         49.77 ± 0.23 	        121.51 ± 0.11
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           pp512 	         49.89 ± 0.23 	        117.52 ± 0.10
 llama 8B Q4_0                  	   4.33 GiB 	     8.03 B 	 CPU        	      10 	           tg128 	         13.40 ± 0.01 	         13.37 ± 0.00

Llama perplexity Results:

Model	                    Base Final PPL Estimate	Patch Final PPL Estimate
granite-4.0-h-micro-Q8_0    1.3862 +/- 0.04424	        1.3868 +/- 0.04432
Meta-Llama3-8b Q4	    1.3801 +/- 0.04116	        1.3803 +/- 0.04116

Signed-off-by: Shalini.Salomi.Bodapati <Shalini.Salomi.Bodapati@ibm.com>
2026-02-19 14:28:53 +08:00
Georgi Gerganov 27326bfce1 models : dedup qwen35 graphs (#19660)
* models : dedup qwen35 graphs

* cont : add missing sigmoid
2026-02-19 08:17:49 +02:00
ymcki ad9f692f8f models : dedup Kimi Linear delta net implementation (#19668)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments [no ci]

* add kimi linear to delta-net-base

* removed unnecessary ggml_cont from g_exp_t

* removed ggml_cont from g_diff_exp_t. moved ggml_cont for o to kimi-linear.cpp

* removed unnecessary diag mask

* cont : simplify

* cont : avoid graph splits

* scale q after mul instead of beginning

* scale q after mul instead of beginning

* identical ppl

* cont : fix scale and decay mask

* minor : remove TODO

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-19 08:15:17 +02:00
Piotr Wilkin (ilintar) 8a70973557 Add Jinja support for "indent" string filter (#19529)
* Add partial Jinja support for "indent" string filter

* Fully implement indent

* Add tests for all width variants.

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix getline ignoring trailing newlines

* Update common/jinja/value.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix first indent condition

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 00:25:52 +01:00
Reese Levine e7f2f95c9a ggml webgpu: Fix bug in dispatching large matrix-vector multiplication (#19535)
* Fix bug in dispatching large matrix-vector multiplication
2026-02-18 16:06:29 -07:00
matteo b55dcdef5d server: save generated text for the /slots endpoint (for LLAMA_SERVER_SLOTS_DEBUG=1) (#19622)
* save generated text for the /slots endpoint

* update debug_generated_text only when LLAMA_SERVER_SLOTS_DEBUG > 0

* Apply suggestions from code review

---------

Co-authored-by: Matteo <matteo@matteo>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-18 18:53:37 +01:00
Xuan-Son Nguyen eeef3cfced model: support GLM-OCR (#19677)
* model: support GLM-OCR

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-18 17:51:40 +01:00
Maciej Lisowski e99f1083a0 docs: Fix broken links for preparing models in Backends (#19684) 2026-02-18 23:50:23 +08:00
Reese Levine 238856ec8f ggml webgpu: shader library organization (#19530)
* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* flashattention and matrix multiplication moved to new format

* clean up preprocessing

* Formatting

* remove duplicate constants

* Split large shaders into multiple static strings

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
2026-02-18 07:51:02 -07:00
Aleksander Grygier ea003229d3 Pre-MCP UI and architecture cleanup (#19689) 2026-02-18 12:02:02 +01:00
Jeff Bolz d0061be838 vulkan: split mul_mat into multiple dispatches to avoid overflow (#19509)
* vulkan: split mul_mat into multiple dispatches to avoid overflow

The batch dimensions can be greater than the max workgroup count limit,
in which case we need to split into multiple dispatches and pass the base
index through a push constant.

Fall back for the less common p021 and nc variants.

* address feedback
2026-02-18 10:47:10 +01:00
Adrien Gallouët a569bda445 common : make small string helpers as inline functions (#19693)
Also use string_view when it make sense and fix some corner cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-18 08:03:01 +01:00
shaofeiqi e2f19b320f opencl: refactor expm1 and softplus (#19404)
* opencl: refactor expm1

* opencl: refactor softplus

* opencl: use h for half literals

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 14:47:18 -08:00
shaofeiqi 983559d24b opencl: optimize mean and sum_row kernels (#19614)
* opencl: optimize mean and sum_row kernels

* opencl: add comment for max subgroups

* opencl: format

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 13:56:09 -08:00
Daniel Bevenius 2b089c7758 model-conversion : add option to print tensor values (#19692)
This commit updates the tensor-info.py script to support the option to
print the first N values of a tensor when displaying its information.

The motivation for this is that it can be useful to inspect some actual
values in addition to the shapes of the tensors.
2026-02-17 20:43:22 +01:00
Aleksander Grygier afa6bfe4f7 Pre-MCP UI and architecture cleanup (#19685)
* webui: extract non-MCP changes from mcp-mvp review split

* webui: extract additional pre-MCP UI and architecture cleanup

* chore: update webui build output
2026-02-17 13:47:45 +01:00
Talha Can Havadar ae2d3f28a8 ggml: ggml-cpu: force-no-lto-for-cpu-feats (#19609)
When LTO enabled in build environments it forces all builds to have LTO
in place. But feature detection logic is fragile, and causing Illegal
instruction errors with lto. This disables LTO for the feature
detection code to prevent cross-module optimization from inlining
architecture-specific instructions into the score function. Without this,
LTO can cause SIGILL when loading backends on older CPUs (e.g., loading
power10 backend on power9 crashes before feature check runs).
2026-02-17 13:22:46 +02:00
Georgi Gerganov ad8207af77 cuda : enable CUDA graphs for MMID 1 <= BS <= 4 (#19645)
* cuda : enable CUDA graphs for MMID BS <= 4

* cont : add stream capture check

Co-authored-by: Oliver Simons <osimons@nvidia.com>

* cont : add MMVQ_MMID_MAX_BATCH_SIZE

---------

Co-authored-by: Oliver Simons <osimons@nvidia.com>
2026-02-17 12:31:49 +02:00
Daniel Bevenius 667b694278 model-conversion : make printing of config values optional (#19681)
* model-conversion : make printing of config values optional

This commit updates run-org-model.py to make the printing of model
configuration values optional.

The motivation for this change is that not all models have these
configuration values defined and those that do not will error when
running this script. With these changes we only print the values if they
exist or a default value.

We could optionally just remove them but it can be useful to see these
values when running the original model.
2026-02-17 10:46:53 +01:00
Sigbjørn Skjæret e48349a49d ci : bump komac version (#19682) 2026-02-17 09:30:31 +01:00
Adrien Gallouët ae46a61e41 build : link ws2_32 as PUBLIC on Windows (#19666)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-17 08:37:07 +01:00
Adrien Gallouët 65cede7c70 build : cleanup library linking logic (#19665)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-17 08:36:45 +01:00
DAN™ 05fa625eac convert : add JoyAI-LLM-Flash (#19651)
* convert_hf_to_gguf: add JoyAI-LLM-Flash tokenizer hash mapping to deepseek-v3

* llama-vocab: create a new pre-tokenizer name for joyai-llm.

* add missing vocab type section

* Update convert_hf_to_gguf_update.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-16 22:49:57 +01:00
AesSedai d612901116 perplexity: add proper batching (#19661) 2026-02-16 18:44:44 +02:00
Ivan Chikish cceb1b4e33 common : inline functions (#18639) 2026-02-16 17:52:24 +02:00
Judd d23a55997d ggml : make ggml_is_view as API (#19539)
* make `ggml_is_view` as API

* introduce `ggml_aux_is_view` as inline version for internal use.

* change `ggml_aux_is_view` to  `ggml_impl_is_view`
2026-02-16 17:43:34 +02:00
Saurabh Dash 5f28c53d11 model: Add support for Tiny Aya Models (#19611)
* changes for tiny aya

* changes to hash

* changes to vocab

* fix some tokenizer regex edge cases

* update comment

* add some comments for regex

* Apply suggestion from @ngxson

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-16 16:28:46 +01:00
Adrien Gallouët 4408494144 build : rework llama_option_depr to handle LLAMA_CURL (#19658)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-16 16:06:48 +01:00
Mario Limonciello 2ba9adc093 Adjust workaround for ROCWMMA_FATTN/GFX9 to only newer ROCm veresions (#19591)
Avoids issues with ROCm 6.4.4.

Closes: https://github.com/ggml-org/llama.cpp/issues/19580
Fixes: 6845f7f87 ("Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)")

Signed-off-by: Mario Limonciello (AMD) <superm1@kernel.org>
2026-02-16 14:46:08 +01:00
Georgi Gerganov cc45f2ada6 models : deduplicate delta-net graphs for Qwen family (#19597)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments
2026-02-16 14:35:04 +02:00
Georgi Gerganov d5dfc33027 graph : fix KQ mask, lora, cvec reuse checks (#19644)
* graph : fix KQ mask reuse condition

* cont : dedup KQ mask build and can_reuse

* cont : fix build

* graph : fix adapter check for reuse
2026-02-16 09:21:11 +02:00
abhijain1204fujitsu 267ba5a1d9 ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel (#19132)
* Updated repack.cpp

* Updated repack.cpp

* Updated repack.cpp

* Added if condition to support only vector length 256.

* Changed the format removed comments and duplicate variable

* If SVE 256 not present then was using generic function to compute, hence slowing the performance. 

So added code if SVE 256 is not present then use NEON code.

* Code format change suggestion

---------

Co-authored-by: Vithule, Prashant <Prashant.Vithule@fujitsu.com>
2026-02-16 14:38:43 +08:00
Georgi Gerganov ff4affb4c1 sync : ggml 2026-02-15 22:24:29 +02:00
Georgi Gerganov 55d58599c8 ggml : bump version to 0.9.7 (ggml/1425) 2026-02-15 22:24:29 +02:00
Georgi Gerganov 1a8c700bfd ggml : bump version to 0.9.6 (ggml/1423) 2026-02-15 22:24:29 +02:00
David Friehs 27b93cbd15 cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization (#19624)
* cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization

- load all 8 int8 for a grid position in one load
- calculate signs via popcnt instead of fetching from ksigns table
- broadcast signs to drop individual shift/mask

* cuda: iq2xxs: simplify sum scaling

express `(sum * scale + sum / 2) / 4` as `(sum * (scale * 2 + 1)) / 8`
express `((aux32 >> 28) * 2 + 1)` as `(aux32 >> 27 | 1)`

saves 3 registers for mul_mat_vec_q (152 -> 149) according to nsight
AFAICT no overflow can occur here as iq2xxs values are far too small

* uint -> uint32_t

error: identifier "uint" is undefined
2026-02-15 22:38:42 +05:30
Aaron Teo 6e67fd2144 docs: update s390x build docs (#19643) 2026-02-16 00:33:34 +08:00
Adrien Gallouët 9e118b97c4 build : remove LLAMA_HTTPLIB option (#19623)
This option was introduced as a workaround because cpp-httplib could not
build on visionOS. Since it has been fixed and now compiles on all platforms,
we can remove it and simplify many things.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-15 15:38:50 +01:00
Daniel Bevenius 57088276d4 cmake : check if KleidiAI API has been fetched (#19640)
This commit addresses a build issue with the KleidiAI backend when
building multiple cpu backends. Commmit
3a00c98584 ("cmake : fix KleidiAI install
target failure with EXCLUDE_FROM_ALL") introduced a change where
FetchContent_Populate is called instead of FetchContent_MakeAvailable,
where the latter does handle this case (it is idempotent but
FetchContent_Populate is not).

I missed this during my review and I should not have commited without
verifying the CI failure, sorry about that.
2026-02-15 13:59:38 +01:00
Georgi Gerganov 341bc7d23c context : fix output reorder with backend sampling (#19638) 2026-02-15 14:57:40 +02:00
Georgi Gerganov 08e6d914b8 ggml : avoid UB in gemm ukernel (#19642) 2026-02-15 14:56:35 +02:00
Aaron Teo 184c694f45 ggml-cpu: optimize ggml_vec_dot_bf16 for s390x (#19399) 2026-02-15 18:20:35 +08:00
Aman Gupta 684b36101c ggml-cpu: FA add GEMM microkernel (#19422)
* ggml-cpu: FA add GEMM microkernel

* add guard for sizeless vector types

* fix case where DV % GGML_F32_EPR !=0

* move memset out of the loop

* move another memset out of the loop

* use RM=4 for arm

* simd_gemm: convert everything to int

* convert everything to size_t to avoid warnings

* fixup

* add pragma for ignoring aggressive loop optimizations
2026-02-15 11:09:24 +05:30
SamareshSingh 3a00c98584 cmake : fix KleidiAI install target failure with EXCLUDE_FROM_ALL (#19581)
* cmake: fix KleidiAI install target failure with EXCLUDE_FROM_ALL

Fix for the bug #19501 by adding EXCLUDE_FROM_ALL to FetchContent_Declare. This properly excludes KleidiAI from both build and install targets, preventing install failures when GGML_CPU_KLEIDIAI=ON is used.

The KleidiAI source files are still compiled into libggml-cpu.so, preserving all functionality.

* addressed code review comments
2026-02-15 06:22:53 +01:00
Sigbjørn Skjæret 079feab9e3 convert : ensure all models handle new experts count (#19621)
* ensure all models handle new experts count

* revert removal for PhiMoeModel, does not inherit from base
2026-02-14 22:22:32 +01:00
Anav Prasad 01d8eaa28d mtmd : Add Nemotron Nano 12B v2 VL support (#19547)
* nemotron nano v2 vlm support added

* simplified code; addressed reviews

* pre-downsample position embeddings during GGUF conversion for fixed input size
2026-02-14 14:07:00 +01:00
Georgi Gerganov 1725e316c1 models : optimize qwen3next graph (#19375)
* models : optimizing qwen3next graph

* cont

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* cont : remove redundant q, g chunking

* minor

* minor

* avoid passing masks around

* avoid concats during chunking

* naming + shapes

* update names and use prefix to disable CUDA graphs
2026-02-14 12:57:36 +02:00
Adrien Gallouët b7742cf321 ggml : fix GGML_DEBUG with OpenMP (#19599)
last_graph is only available without OpenMP, but
ggml_graph_compute_thread() is called in both cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 11:22:57 +01:00
iMil badba89320 NetBSD build support (#19589) 2026-02-14 09:47:01 +01:00
Aleksander Grygier baa12f3831 webui: Architecture and UI improvements (#19596) 2026-02-14 09:06:41 +01:00
agent-enemy-2 2d8015e8a4 llama : update LoRA API. + fix excessive graph reserves (#19280)
* Refactoring to use new llama_put_adapter_loras

* cont : alternative lora API

---------

Co-authored-by: Jake Chavis <jakechavis6@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-14 10:06:27 +02:00
George eb145c0753 mmap: Fix Windows handle lifetime (#19598)
* ggml: added cleanups in ggml_quantize_free
Add missing cleanup calls for IQ2_S, IQ1_M quantization types and IQ3XS with 512 blocks during quantization cleanup.

* mmap: Fix Windows handle lifetime
Move hMapping from local variable to member variable so it stays alive for the entire lifetime of the mapping.
The file mapping handle must remain valid until UnmapViewOfFile is called.
Fixes cleanup order in destructor.

* Update llama-mmap.cpp

* Update llama-mmap.cpp

Remove trailing whitespace from line 567
2026-02-14 10:05:12 +02:00
Georgi Gerganov 6e473fb384 metal : fix ACC op (#19427) 2026-02-14 09:54:03 +02:00
Adrien Gallouët c7db95f106 scripts : use official split.py for cpp-httplib (#19588)
* scripts : use official split.py for cpp-httplib

Using the official script is safer and ensures the generated code aligns
with the library's standards.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Catch generic errors

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Allow print()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Ensure robust cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 08:41:16 +01:00
Sigbjørn Skjæret 0d00ef65ed convert : store ffn_gate_inp_shexp as F32 (#19606) 2026-02-14 08:17:43 +01:00
Adrien Gallouët 91ea5d67f2 build : fix libtool call in build-xcframework.sh (#19605)
Run libtool via xcrun like strip and dsymutil, to have proper tool resolution.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 06:48:37 +01:00
Jeff Bolz dbb023336b vulkan: support L2_NORM with contiguous rows (#19604) 2026-02-14 06:42:04 +01:00
Jeff Bolz 53aef25a88 vulkan: support GGML_OP_SET (#19584) 2026-02-14 06:36:38 +01:00
Sophon 2dec548094 vulkan: Add vendor id for Qualcomm drivers (#19569)
This commit allows Qualcomm native vulkan driver to be used on Windows
instead of Mesa Dozen.
2026-02-14 06:29:17 +01:00
Max Krasnyansky 0ccbfdef3e hexagon: further optimizations and refactoring for flash attention (#19583)
* ggml-hexagon: fa improvements

ggml-hexagon: optimize flash attention calculations with improved variable handling

ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32

ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx2 by simplifying variable handling for unused elements

ggml-hexagon: optimize flash attention by changing slope vector type to F16

* hexfa: fixed test-backend-ops failurs due to leftover element handling

* hexagon: refactor and optimize fa to use local context struct

* ggml-hexagon: optimize flash-attention using hvx_vec_expf

Use HVX for online softmax.

---------

Co-authored-by: chraac <chraac@gmail.com>
2026-02-13 16:27:30 -08:00
Mengsheng Wu 94a602db66 github : add missing backends to issue templates (#19603) 2026-02-14 00:56:53 +01:00
Jeff Bolz 05a6f0e894 vulkan: restore -inf check in FA shaders (#19582) 2026-02-13 13:35:29 -06:00
Adrien Gallouët b48e80f677 common : update download code (#19573)
* common : remove legacy .json to .etag migration code

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* common : simplify common_download_file_single_online

This commit also force a redownload if the file exists
but has no .etag file.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 15:10:46 +01:00
Xuan-Son Nguyen 752584d5f5 model: support GLM MoE DSA arch (NOTE: indexer is not yet supported) (#19460)
* model: support GLM MoE DSA arch

* working version

* pyright

* keep indexer tensors

* add indexer gguf params

* loaded now

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* minor fix and cleanup

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-13 14:56:53 +01:00
Alberto Cabrera Pérez cc2aa81513 Fix wrong memcpy length for block_interleave == 4 (#19575) 2026-02-13 20:32:14 +08:00
ymcki 0e21991472 fix vulkan ggml_acc only works in 3d but not 4d (#19426)
* fix vulkan ggml_acc only works in 3d but not 4d

* removed clamp in test_acc_block

* use the correct stride and its test case

* cuda : fix "supports op" condition

* change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check

* version without boundary check

* revert back to boundary check version

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-13 13:31:37 +01:00
Sigbjørn Skjæret b2ecc0cdb4 support --verbose-prompt (#19576) 2026-02-13 12:49:10 +01:00
Aman Gupta 5065da554e CUDA: loop over ne2*ne3 in case it overflows (#19538)
* CUDA: loop over ne2*ne3 in case it overflows

* use fastdiv
2026-02-13 17:01:40 +05:30
Aleksander Grygier 5174d7206f webui: UI and routing fixes (#19586)
* chore: update webui build output

* chore: update webui build output

* fix: Scroll issues in DropdownMenuSearchable

* webui: fix redirect to root ignoring base path

* fix: Word wrapping

* fix: remove obsolete modality UI tests causing CI failures

- Remove VisionModality/AudioModality test stories
- Remove mockServerProps usage and imports
- Simplify Default test (remove dropdown interaction checks)
- Simplify FileAttachments test (remove mocks)

* feat: Improve formatting performance time

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-02-13 12:31:00 +01:00
Oliver Simons 43919b7f4f CUDA: Do not mutate cgraph for fused ADDs (#19566)
* Do not mutate cgraph for fused ADDs

1. We should try to minimize in-place changes to the incoming
   ggml_cgraph where possible (those should happen in graph_optimize)
2. Modifying in-place leads to an additional, unnecessary graph capture
   step as we store the properties before modifying the graph in-place
   in the cuda-backend

* Assert ggml_tensor is trivially copyable

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-02-13 15:07:55 +05:30
Pavan Shinde 423cf0b26f docs : fix broken link and typo (#19560) 2026-02-13 09:38:09 +01:00
ymcki 33a56f90a6 model : Kimi Linear fix conv state update (#19531)
* fix conv state update for llama-server parallel serving

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-02-13 09:10:18 +01:00
Adrien Gallouët 25224c8021 llama : remove deprecated codecvt (#19565)
Using the same conversion function ensures a consistent matching between
the regex pattern and the text.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:53 +01:00
Adrien Gallouët 2f5d8f8edc vendor : update BoringSSL to 0.20260211.0 (#19562)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:26 +01:00
Georgi Gerganov bb96bfd361 memory : fix kv cache size for hybrid models (#19559) 2026-02-13 07:36:24 +02:00
Georgi Gerganov 0644baefde metal : improve concurrency (#19555) 2026-02-13 07:35:57 +02:00
Georgi Gerganov 490eb96b88 metal : support GGML_OP_SET (#19548) 2026-02-13 07:34:52 +02:00
Shupei Fan 3bb78133ab hexagon: fix typo in vtcm_needs_release (#19545) 2026-02-12 15:07:49 -08:00
lhez 79cc0f2daf opencl: add basic support for q4_1 (#19534)
* opencl: add q4_1 mv

* opencl: clean up

* opencl: add flattened q4_1 mv

* opencl: clean up

* opencl: add basic q4_1 mm

* opencl: fix whitespace

* opencl: add general q4_0 mm
2026-02-12 14:52:37 -08:00
Georgi Gerganov 338085c69e args : add -kvu to llama-parallel (#19577) 2026-02-12 21:52:41 +02:00
Aleksander Grygier 4c61875bf8 webui: Add switcher to Chat Message UI to show raw LLM output (#19571) 2026-02-12 19:55:51 +01:00
Adrien Gallouët 4b385bfcf8 vendor : update cpp-httplib (#19537)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-12 16:11:22 +01:00
Christian Schmitz f488429380 llama : update outdated comment in llama.h (#19428)
* Updated documentation

Model is no longer a parameter

* llama : fix trailing whitespace in comment

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
2026-02-12 15:52:57 +01:00
Aleksander Grygier 4d688f9ebb (webui) FEATURE: Enable adding or injecting System Message into chat (#19556)
* feat: Enable adding System Prompt per-chat

* fix: Save draft message in Chat Form when adding System Prompt from new chat view

* fix: Proper system message deletion logic

* chore: Formatting

* chore: update webui build output
2026-02-12 13:56:08 +01:00
Daniel Bevenius ff599039a9 scripts : add support for forks in pr2wt.sh (#19540)
This commit adds support for using the pr2wt.sh (pull request to
workspace) script with forks of upstream llama.cpp.
2026-02-12 13:14:28 +01:00
Aleksander Grygier f486ce9f30 (webui) REFACTOR: UI primitives and polish (#19551)
* webui: UI primitives and polish (non-MCP)

* chore: update webui build output
2026-02-12 12:21:00 +01:00
Aleksander Grygier 38adc7d469 WebUI Architecture Cleanup (#19541)
* webui: architecture foundation (non-MCP core refactors)

* chore: update webui build output
2026-02-12 11:22:27 +01:00
Georgi Gerganov 3b3a948134 metal : update sum_rows kernel to support float4 (#19524) 2026-02-12 11:35:28 +02:00
Mario Limonciello 6845f7f87f Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)
There is an upstream problem [1] with AMD's LLVM 22 fork and
rocWMMA 2.2.0 causing compilation issues on devices without
native fp16 support (CDNA devices).

The specialized types aren't resolved properly:
```
/opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>'
 2549 |             using ARegsT = typename Impl::ARegsT;
```

Add a workaround to explicitly declare the types and cast when
compiling with HIP and ROCWMMA_FATTN [2].  When this is actually
fixed upstream some guards can be used to detect and wrap the
version that has the fix to only apply when necessary.

Link: https://github.com/ROCm/rocm-libraries/issues/4398 [1]
Link: https://github.com/ggml-org/llama.cpp/issues/19269 [2]

Signed-off-by: Mario Limonciello <mario.limonciello@amd.com>
2026-02-12 09:38:35 +01:00
RichardScottOZ fa16e517a3 server : fix typo in README.md for features list (#19510)
extra l for full
2026-02-12 08:56:25 +01:00
TriDefender 313493de53 docs : update path in snapdragon README.md (#19533)
paths changed so original example didn't work
2026-02-12 08:13:51 +01:00
Max Krasnyansky b1ff83bbb0 hexagon: further optimization and tuning of matmul and dot kernels (#19407)
* ggml-hexagon: implement 2x2 matmul kernel

* hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4

* hexagon: fix editor config failures

* hexagon: refactor matmul ops to use context struct and remove wrappers

Also implement vec_dot_f16 2x2

* hexagon: refactor dyn quantizers to use mmctx

* hexagon: remove mm fastdiv from op_ctx

* hexagon: refactor matmul entry point to reduce code duplication

---------

Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
2026-02-11 23:04:27 -08:00
Adrien Gallouët 4ae1b7517a common : replace deprecated codecvt using parse_utf8_codepoint (#19517)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-12 07:27:52 +01:00
lhez 4d3daf80f8 opencl: add general Q6_K mm and Q4_K mv (#19347)
* opencl: add general q6_k mm

* opencl: refine condition for q6_K mm

* opencl: add general q4_K mv

* opencl: fix whitespace
2026-02-11 10:33:13 -08:00
Georgi Gerganov 914dde72ba ggml : unary ops support non-cont src0 + metal F16 unary ops (#19511)
* ggml : unary ops support non-cont src0

* metal : support F16 unary ops + fix ELU
2026-02-11 18:58:43 +02:00
Daniel Bevenius 3136a849db common : remove unused token util functions (#19506)
This commit removes two unused functions `common_lcp` and `common_lcs`.
The last usage of these functions was removed in
Commit 33eff40240 ("server : vision support
via libmtmd") and are no longer used anywhere in the codebase.
2026-02-11 17:41:35 +01:00
AesSedai e463bbdf65 model: Add Kimi-K2.5 support (#19170)
* Move dequant_model to after the text_config merge
Add new kimi-k2.5 keys to mtmd convert
Update V_MMPROJ tensor mapping for new mm_projector.proj keys
Update V_M_IMP_NORM for new mm_projector.pre_norm key

* Fix a couple of oversights

* Add image support for Kimi-K2.5

* Revert changes to KimiVLForConditionalGeneration

* Fix an assert crash

* Fix permute swapping w / h on accident

* Kimi-K2.5: Use merged QKV for vision

* Kimi-K2.5: pre-convert vision QK to use build_rope_2d

* Kimi-K2.5: support non-interleaved rope for vision

* Kimi-K2.5: fix min / max pixel

* Kimi-K2.5: remove v/o permutes, unnecessary

* Kimi-K2.5: update permute name to match

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Kimi-K2.5: replace build_rope_2d ggml_cont with ggml_view_3d pointers

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-11 16:47:30 +01:00
Daniel Bevenius 53de59f67d build : fix case in dSYMs path for build-macos [no ci] (#19515)
This commit updates an incorrect dSYMs where the the 's' was uppercase
by mistake.

The motivation for fixing this is that this can cause issues on case
sensitive operating systems.

Refs: https://github.com/ggml-org/whisper.cpp/pull/3630
2026-02-11 14:02:29 +01:00
Georgi Gerganov 9ab072ebbe metal : extend l2_norm support for non-cont src0 (#19502) 2026-02-11 14:53:19 +02:00
Johannes Gäßler ada90bf2ba docs: ban AI for issues and discussions [no CI] (#19512) 2026-02-11 12:49:40 +01:00
334 changed files with 29566 additions and 15097 deletions
+6 -7
View File
@@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
ARG ROCM_VERSION=7.0
ARG AMDGPU_VERSION=7.0
ARG ROCM_VERSION=7.2
ARG AMDGPU_VERSION=7.2
# Target the ROCm build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
@@ -11,13 +11,12 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
# Unless otherwise specified, we make a fat build.
# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878
# This is mostly tied to rocBLAS supported archs.
# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
#ARG ROCM_DOCKER_ARCH='gfx1151'
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
# Set ROCm architectures
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
@@ -41,7 +41,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true
+1 -1
View File
@@ -42,7 +42,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true
+98
View File
@@ -516,6 +516,102 @@ jobs:
path: llama-bin-win-sycl-x64.zip
name: llama-bin-win-sycl-x64.zip
ubuntu-22-rocm:
runs-on: ubuntu-22.04
strategy:
matrix:
include:
- ROCM_VERSION: "7.2"
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
build: 'x64'
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
sudo apt install -y build-essential git cmake wget
- name: Setup Legacy ROCm
if: matrix.ROCM_VERSION == '7.2'
id: legacy_env
run: |
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \
gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null
sudo tee /etc/apt/sources.list.d/rocm.list << EOF
deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${{ matrix.ROCM_VERSION }} jammy main
EOF
sudo tee /etc/apt/preferences.d/rocm-pin-600 << EOF
Package: *
Pin: release o=repo.radeon.com
Pin-Priority: 600
EOF
sudo apt update
sudo apt-get install -y libssl-dev rocm-hip-sdk
- name: Setup TheRock
if: matrix.ROCM_VERSION != '7.2'
id: therock_env
run: |
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
mkdir install
tar -xf *.tar.gz -C install
export ROCM_PATH=$(pwd)/install
echo ROCM_PATH=$ROCM_PATH >> $GITHUB_ENV
echo PATH=$PATH:$ROCM_PATH/bin >> $GITHUB_ENV
echo LD_LIBRARY_PATH=$ROCM_PATH/lib:$ROCM_PATH/llvm/lib:$ROCM_PATH/lib/rocprofiler-systems >> $GITHUB_ENV
- name: Build with native CMake HIP support
id: cmake_build
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BACKEND_DL=ON \
-DGGML_NATIVE=OFF \
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
-DGPU_TARGETS="${{ matrix.gpu_targets }}" \
-DGGML_HIP=ON \
-DHIP_PLATFORM=amd \
-DGGML_HIP_ROCWMMA_FATTN=ON \
${{ env.CMAKE_ARGS }}
cmake --build build --config Release -j $(nproc)
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
windows-hip:
runs-on: windows-2022
@@ -784,6 +880,7 @@ jobs:
- windows-cuda
- windows-sycl
- windows-hip
- ubuntu-22-rocm
- ubuntu-22-cpu
- ubuntu-22-vulkan
- macOS-arm64
@@ -868,6 +965,7 @@ jobs:
**Linux:**
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz)
- [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz)
**Windows:**
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Install komac
run: |
cargo binstall komac@2.11.2 -y
cargo binstall komac@2.15.0 -y
- name: Find latest release
id: find_latest_release
+1 -1
View File
@@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
- Explicitly informing them that AI-generated pull requests are not accepted by the project
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Providing useful links and pointers found throughout the codebase
Examples of valid questions:
+11 -13
View File
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
project("llama.cpp" C CXX)
include(CheckIncludeFileCXX)
@@ -112,15 +112,9 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
# 3rd party libs
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
# deprecated
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
if (LLAMA_CURL)
message(WARNING "LLAMA_CURL option is deprecated and will be ignored")
endif()
# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -148,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS)
endif()
# transition helpers
function (llama_option_depr TYPE OLD NEW)
function (llama_option_depr TYPE OLD)
if (${OLD})
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
set(${NEW} ON PARENT_SCOPE)
set(NEW "${ARGV2}")
if(NEW)
message(${TYPE} "${OLD} is deprecated, use ${NEW} instead")
set(${NEW} ON PARENT_SCOPE)
else()
message(${TYPE} "${OLD} is deprecated and will be ignored")
endif()
endif()
endfunction()
@@ -164,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
llama_option_depr(WARNING LLAMA_CURL)
include("cmake/license.cmake")
license_add_file("llama.cpp" "LICENSE")
@@ -197,9 +197,7 @@ add_subdirectory(src)
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
if (LLAMA_HTTPLIB)
add_subdirectory(vendor/cpp-httplib)
endif()
add_subdirectory(vendor/cpp-httplib)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+1 -1
View File
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
1. Explicitly disclose the manner in which AI was employed.
2. Perform a comprehensive manual review prior to submitting the pull request.
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
+1 -1
View File
@@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
> [!IMPORTANT]
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
## Requirements
+14 -18
View File
@@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=(
-DGGML_OPENMP=${GGML_OPENMP}
)
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
check_required_tool() {
local tool=$1
local install_message=$2
@@ -60,9 +55,12 @@ check_required_tool() {
}
echo "Checking for required tools..."
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
set -e
@@ -260,7 +258,7 @@ combine_static_libraries() {
# Since we have multiple architectures libtool will find object files that do not
# match the target architecture. We suppress these warnings.
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
# Determine SDK, architectures, and install_name based on platform and simulator flag.
local sdk=""
@@ -333,7 +331,7 @@ combine_static_libraries() {
# Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then
if command -v xcrun vtool &>/dev/null; then
if xcrun -f vtool &>/dev/null; then
case "$platform" in
"ios")
echo "Marking binary as a framework binary for iOS..."
@@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos --config Release -- -quiet
@@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos-sim --config Release -- -quiet
@@ -528,13 +524,13 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
# Create XCFramework with correct debug symbols paths
echo "Creating XCFramework..."
xcodebuild -create-xcframework \
xcrun xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-sim/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
-framework $(pwd)/build-ios-device/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-macos/framework/llama.framework \
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
+11 -27
View File
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
llama_add_compile_flags()
# Build info header
#
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
@@ -110,33 +109,16 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
set(LLAMA_COMMON_EXTRA_LIBS build_info)
if (LLAMA_HTTPLIB)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
endif()
target_link_libraries(${TARGET} PRIVATE
build_info
cpp-httplib
)
if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
# Set the correct library file extension based on platform
if (WIN32)
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
# Add Windows-specific libraries
set(LLGUIDANCE_PLATFORM_LIBS
ws2_32 # Windows Sockets API
userenv # For GetUserProfileDirectoryW
ntdll # For NT functions
bcrypt # For BCryptGenRandom
)
else()
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
set(LLGUIDANCE_PLATFORM_LIBS "")
endif()
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
@@ -158,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
add_dependencies(llguidance llguidance_ext)
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
# Add platform libraries to the main target
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
endif ()
target_link_libraries(${TARGET} PRIVATE llguidance)
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
endif()
endif()
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
+1 -1
View File
@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
+1 -1
View File
@@ -803,7 +803,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
}
// remove potential partial suffix
if (builder.pos() == builder.input().size()) {
if (builder.pos() == builder.input().size() && builder.is_partial()) {
if (unclosed_reasoning_content.empty()) {
rstrip(content);
trim_potential_partial_word(content);
-20
View File
@@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
@@ -1590,9 +1573,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
+29 -51
View File
@@ -65,14 +65,25 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
} else if (!content_parts.empty()) {
if (concat_typed_text) {
std::string text;
bool last_was_media_marker = false;
// join parts with newline, do not add newline before or after media markers
for (const auto & part : content_parts) {
if (part.type != "text") {
bool add_new_line = true;
if (part.type == "text") {
add_new_line = !last_was_media_marker && !text.empty();
last_was_media_marker = false;
} else if (part.type == "media_marker") {
add_new_line = false;
last_was_media_marker = true;
} else {
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
continue;
}
if (!text.empty()) {
if (add_new_line) {
text += '\n';
}
text += part.text;
}
jmsg["content"] = text;
@@ -319,7 +330,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
throw std::invalid_argument("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
if (type != "text" && type != "media_marker") {
throw std::invalid_argument("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
@@ -725,7 +736,6 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
@@ -1511,14 +1521,17 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
return data;
}
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
// Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking
bool supports_reasoning = (tmpl.source().find("<think>") != std::string::npos);
// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (supports_reasoning && string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
@@ -1527,19 +1540,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_
}
data.preserved_tokens = {
"<think>",
"</think>",
"<tool_call>",
"</tool_call>",
};
if (supports_reasoning) {
data.preserved_tokens.insert(data.preserved_tokens.end(), {"<think>", "</think>"});
}
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
auto reasoning = p.eps();
if (inputs.enable_thinking && extract_reasoning) {
if (supports_reasoning && inputs.enable_thinking && extract_reasoning) {
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
if (data.thinking_forced_open) {
reasoning = reasoning_content;
@@ -1877,38 +1892,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
return data;
}
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.prompt = apply(tmpl, params);
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
data.preserved_tokens = {
"<tool_call>",
"</tool_call>",
"<function=",
"</function>",
"<parameter=",
"</parameter>",
};
// build grammar for tool call
static const xml_tool_call_format form {
/* form.scope_start = */ "<tool_call>\n",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">\n",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">\n",
/* form.val_end = */ "\n</parameter>\n",
/* form.tool_end = */ "</function>\n",
/* form.scope_end = */ "</tool_call>",
};
build_grammar_xml_tool_call(data, params.tools, form);
return data;
}
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -2032,6 +2015,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
if (has_reasoning_content && has_tool_calls) {
auto adjusted_message = msg;
adjusted_message["thinking"] = msg.at("reasoning_content");
adjusted_message.erase("content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
@@ -3129,19 +3113,13 @@ static common_chat_params common_chat_templates_apply_jinja(
}
// Qwen3-Coder XML format detection (must come before Hermes 2 Pro)
// Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates.
// Require presence of <tool_call>, <function=...>, and <parameter=...> blocks.
// Detect via XML markers: <tool_call>, <function=...>, and <parameter=...> blocks.
// Also matches Step-3.5-Flash and Nemotron 3 Nano which use the same output format.
if (src.find("<tool_call>") != std::string::npos &&
src.find("<function>") != std::string::npos &&
src.find("<function=") != std::string::npos &&
src.find("<parameters>") != std::string::npos &&
src.find("<parameter=") != std::string::npos) {
workaround::func_args_not_string(params.messages);
// Nemotron 3 Nano 30B A3B
if (src.find("<think>") != std::string::npos) {
return common_chat_params_init_nemotron_v3(tmpl, params);
}
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
return common_chat_params_init_qwen3_coder(tmpl, params);
}
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
@@ -3307,7 +3285,7 @@ static common_chat_params common_chat_templates_apply_legacy(
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
if (part.type != "text" && part.type != "media_marker") {
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
continue;
}
-1
View File
@@ -128,7 +128,6 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GLM_4_5,
COMMON_CHAT_FORMAT_MINIMAX_M2,
COMMON_CHAT_FORMAT_KIMI_K2,
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
COMMON_CHAT_FORMAT_APRIEL_1_5,
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
COMMON_CHAT_FORMAT_SOLAR_OPEN,
+32 -135
View File
@@ -1,7 +1,3 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "gguf.h"
@@ -9,12 +5,12 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "unicode.h"
#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <chrono>
#include <cstdarg>
#include <cstring>
@@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
bool has_suffix = string_ends_with(str, suffix);
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$&");
@@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
return false;
}
std::u32string filename_utf32;
try {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
size_t offset = 0;
while (offset < filename.size()) {
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
filename_utf32 = converter.from_bytes(filename);
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
// or invalid encodings were encountered. Reject such attempts
std::string filename_reencoded = converter.to_bytes(filename_utf32);
if (filename_reencoded != filename) {
if (result.status != utf8_parse_result::SUCCESS) {
return false;
}
} catch (const std::exception &) {
return false;
}
uint32_t c = result.codepoint;
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
for (char32_t c : filename_utf32) {
if ((result.bytes_consumed == 2 && c < 0x80) ||
(result.bytes_consumed == 3 && c < 0x800) ||
(result.bytes_consumed == 4 && c < 0x10000)) {
return false;
}
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
if (c <= 0x1F // Control characters (C0)
|| c == 0x7F // Control characters (DEL)
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
@@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|| c == 0x2215 // Division Slash (forward slash equivalent)
|| c == 0x2216 // Set Minus (backslash equivalent)
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c > 0x10FFFF // Max Unicode limit
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == ':' || c == '*' // Illegal characters
@@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
// Subdirectories not allowed, reject path separators
return false;
}
offset += result.bytes_consumed;
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -898,7 +851,8 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
defined(__OpenBSD__) || defined(__NetBSD__)
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else if (std::getenv("HOME")) {
@@ -1242,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
return res;
}
int err = llama_apply_adapter_cvec(
int err = llama_set_adapter_cvec(
lctx,
cvec.data.data(),
cvec.data.size(),
@@ -1344,12 +1298,15 @@ std::string get_model_endpoint() {
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_set_adapter_lora(ctx, la.ptr, la.scale);
}
std::vector<llama_adapter_lora *> loras;
std::vector<float> scales;
for (auto & la: lora) {
loras.push_back(la.ptr);
scales.push_back(la.scale);
}
llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
}
struct llama_model_params common_model_params_to_llama(common_params & params) {
@@ -1469,66 +1426,6 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
//
// Vocab utils
//
+41 -26
View File
@@ -670,30 +670,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
}
template<>
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
size_t delim_pos = str.find(delim);
while (delim_pos != std::string::npos) {
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
begin_pos = delim_pos + 1;
delim_pos = str.find(delim, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
parts.emplace_back(str.substr(begin_pos));
return parts;
}
static bool string_starts_with(const std::string & str,
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
// remove when moving to c++20
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
return str.size() >= prefix.size() &&
str.compare(0, prefix.size(), prefix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
// remove when moving to c++20
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
return str.size() >= suffix.size() &&
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
if (string_ends_with(str, suffix)) {
str.resize(str.size() - suffix.size());
return true;
}
return false;
}
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
if (!str.empty() && !stop.empty()) {
const size_t max_len = std::min(str.size(), stop.size());
const char last_char = str.back();
for (size_t len = max_len; len > 0; --len) {
if (stop[len - 1] == last_char) {
if (string_ends_with(str, stop.substr(0, len))) {
return str.size() - len;
}
}
}
}
return std::string::npos;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -779,16 +804,6 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longet common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//
@@ -880,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
static std::string llm_ffn_exps_block_regex(int idx) {
inline std::string llm_ffn_exps_block_regex(int idx) {
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
}
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
}
+74 -135
View File
@@ -19,9 +19,7 @@
#include <thread>
#include <vector>
#if defined(LLAMA_USE_HTTPLIB)
#include "http.h"
#endif
#ifndef __EMSCRIPTEN__
#ifdef __linux__
@@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
}
static std::string read_etag(const std::string & path) {
std::string none;
const std::string etag_path = path + ".etag";
if (std::filesystem::exists(etag_path)) {
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return none;
}
std::string etag;
std::getline(etag_in, etag);
return etag;
if (!std::filesystem::exists(etag_path)) {
return {};
}
// no etag file, but maybe there is an old .json
// remove this code later
const std::string metadata_path = path + ".json";
if (std::filesystem::exists(metadata_path)) {
std::ifstream metadata_in(metadata_path);
try {
nlohmann::json metadata_json;
metadata_in >> metadata_json;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata_json.dump().c_str());
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
std::string etag = metadata_json.at("etag");
write_etag(path, etag);
if (!std::filesystem::remove(metadata_path)) {
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
}
return etag;
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return {};
}
return none;
std::string etag;
std::getline(etag_in, etag);
return etag;
}
static bool is_http_status_ok(int status) {
@@ -168,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
return {hf_repo, tag};
}
#if defined(LLAMA_USE_HTTPLIB)
class ProgressBar {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
@@ -347,62 +317,64 @@ static int common_download_file_single_online(const std::string & url,
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
for (int i = 0; i < max_attempts; ++i) {
auto head = cli.Head(parts.path);
bool head_ok = head && head->status >= 200 && head->status < 300;
if (!head_ok) {
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head->status; // cannot use cached file, return raw status code
// TODO: maybe retry only on certain codes
}
std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head_ok && head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head_ok && head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
bool should_download_from_scratch = false;
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
last_etag.c_str(), etag.c_str());
should_download_from_scratch = true;
}
auto head = cli.Head(parts.path);
if (!head || head->status < 200 || head->status >= 300) {
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
if (!should_download_from_scratch) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head ? head->status : -1;
}
std::string etag;
if (head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
if (file_exists) {
if (etag.empty()) {
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (!last_etag.empty() && last_etag == etag) {
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
}
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;
for (int i = 0; i < max_attempts; ++i) {
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
delay *= retry_delay_seconds;
}
const std::string path_temporary = path + ".downloadInProgress";
size_t existing_size = 0;
if (std::filesystem::exists(path_temporary)) {
if (supports_ranges && !should_download_from_scratch) {
if (supports_ranges) {
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
@@ -410,32 +382,23 @@ static int common_download_file_single_online(const std::string & url,
}
}
// start the download
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
} else {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str());
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
continue;
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status;
}
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status; // TODO: use actual GET status?
}
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}
@@ -801,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
}
}
#else
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
int common_download_file_single(const std::string &,
const std::string &,
const std::string &,
bool,
const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // defined(LLAMA_USE_HTTPLIB)
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
+8 -7
View File
@@ -85,7 +85,7 @@ value identifier::execute_impl(context & ctx) {
auto builtins = global_builtins();
if (!it->is_undefined()) {
if (ctx.is_get_stats) {
it->stats.used = true;
value_t::stats_t::mark_used(it);
}
JJ_DEBUG("Identifier '%s' found, type = %s", val.c_str(), it->type().c_str());
return it;
@@ -277,7 +277,7 @@ value binary_expression::execute_impl(context & ctx) {
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
if (ctx.is_get_stats) {
input->stats.used = true;
value_t::stats_t::mark_used(input);
input->stats.ops.insert(name);
}
auto builtins = input->get_builtins();
@@ -448,7 +448,7 @@ value for_statement::execute_impl(context & ctx) {
// mark the variable being iterated as used for stats
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
value_t::stats_t::mark_used(iterable_val);
iterable_val->stats.ops.insert("array_access");
}
@@ -470,7 +470,7 @@ value for_statement::execute_impl(context & ctx) {
items.push_back(std::move(tuple));
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
value_t::stats_t::mark_used(iterable_val);
iterable_val->stats.ops.insert("object_access");
}
} else {
@@ -480,7 +480,7 @@ value for_statement::execute_impl(context & ctx) {
items.push_back(item);
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
value_t::stats_t::mark_used(iterable_val);
iterable_val->stats.ops.insert("array_access");
}
}
@@ -817,8 +817,9 @@ value member_expression::execute_impl(context & ctx) {
}
if (ctx.is_get_stats && val && object && property) {
val->stats.used = true;
object->stats.used = true;
value_t::stats_t::mark_used(val);
value_t::stats_t::mark_used(object);
value_t::stats_t::mark_used(property);
if (is_val<value_int>(property)) {
object->stats.ops.insert("array_access");
} else if (is_val<value_string>(property)) {
+73 -2
View File
@@ -4,6 +4,7 @@
// for converting from JSON to jinja values
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
#include <cctype>
#include <vector>
@@ -160,6 +161,11 @@ static value tojson(const func_args & args) {
value val_separators = args.get_kwarg_or_pos("separators", 3);
value val_sort = args.get_kwarg_or_pos("sort_keys", 4);
int indent = -1;
if (args.ctx.is_get_stats) {
// mark as used (recursively) for stats
auto val_input = args.get_pos(0);
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
}
if (is_val<value_int>(val_indent)) {
indent = static_cast<int>(val_indent->as_int());
}
@@ -715,8 +721,46 @@ const func_builtins & value_string_t::get_builtins() const {
return args.get_pos(0);
}},
{"tojson", tojson},
{"indent", [](const func_args &) -> value {
throw not_implemented_exception("String indent builtin not implemented");
{"indent", [](const func_args &args) -> value {
args.ensure_count(1, 4);
value val_input = args.get_pos(0);
value val_width = args.get_kwarg_or_pos("width", 1);
const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
if (!is_val<value_string>(val_input)) {
throw raised_exception("indent() first argument must be a string");
}
std::string indent;
if (is_val<value_int>(val_width)) {
indent.assign(val_width->as_int(), ' ');
} else if (is_val<value_string>(val_width)) {
indent = val_width->as_string().str();
} else {
indent = " ";
}
std::string indented;
std::string input = val_input->as_string().str();
std::istringstream iss = std::istringstream(input);
std::string line;
while (std::getline(iss, line)) {
if (!indented.empty()) {
indented.push_back('\n');
}
if ((indented.empty() ? first : (!line.empty() || blank))) {
indented += indent;
}
indented += line;
}
if (!input.empty() && input.back() == '\n') {
indented.push_back('\n');
if (blank) {
indented += indent;
}
}
auto res = mk_val<value_string>(indented);
res->val_str.mark_input_based_on(val_input->as_string());
return res;
}},
{"join", [](const func_args &) -> value {
throw not_implemented_exception("String join builtin not implemented");
@@ -852,6 +896,11 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
if (args.ctx.is_get_stats) {
// mark as used (recursively) for stats
auto val_input = args.get_pos(0);
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
}
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"tojson", tojson},
@@ -1007,6 +1056,11 @@ const func_builtins & value_object_t::get_builtins() const {
{"tojson", tojson},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
if (args.ctx.is_get_stats) {
// mark as used (recursively) for stats
auto val_input = args.get_pos(0);
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
}
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"length", [](const func_args & args) -> value {
@@ -1319,4 +1373,21 @@ std::string value_to_string_repr(const value & val) {
}
}
// stats utility
void value_t::stats_t::mark_used(value & val, bool deep) {
val->stats.used = true;
if (deep) {
if (is_val<value_array>(val)) {
for (auto & item : val->val_arr) {
mark_used(item, deep);
}
} else if (is_val<value_object>(val)) {
for (auto & pair : val->val_obj) {
mark_used(pair.first, deep);
mark_used(pair.second, deep);
}
}
}
}
} // namespace jinja
+2
View File
@@ -118,6 +118,8 @@ struct value_t {
bool used = false;
// ops can be builtin calls or operators: "array_access", "object_access"
std::set<std::string> ops;
// utility to recursively mark value and its children as used
static void mark_used(value & val, bool deep = false);
} stats;
value_t() = default;
+431 -107
View File
@@ -160,8 +160,6 @@ class ModelBase:
self.ftype = gguf.LlamaFileType.MOSTLY_F16
logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -527,6 +525,8 @@ class ModelBase:
return ()
def prepare_tensors(self):
self.dequant_model()
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -570,6 +570,7 @@ class ModelBase:
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.FFN_GATE_INP_SHEXP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
@@ -1048,6 +1049,9 @@ class TextModel(ModelBase):
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
res = "glm4"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
@@ -1081,9 +1085,6 @@ class TextModel(ModelBase):
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -1123,6 +1124,9 @@ class TextModel(ModelBase):
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
res = "command-r"
if chkhsh == "d772b220ace2baec124bed8cfafce0ead7d6c38a4b65ef11261cf9d5d62246d1":
# ref: https://huggingface.co/CohereLabs/tiny-aya-base
res = "tiny_aya"
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
res = "qwen2"
@@ -1159,6 +1163,9 @@ class TextModel(ModelBase):
if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
# ref: https://huggingface.co/core42/jais-13b
res = "jais"
if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a":
# ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat
res = "jais-2"
if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f":
# ref: https://huggingface.co/WisdomShell/CodeShell-7B
res = "codeshell"
@@ -1264,6 +1271,12 @@ class TextModel(ModelBase):
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
res = "qwen35"
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
res = "joyai-llm"
if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869":
# ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601
res = "kanana2"
if res is None:
logger.warning("\n")
@@ -1608,6 +1621,23 @@ class TextModel(ModelBase):
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_glm(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
@@ -1815,7 +1845,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -1870,7 +1900,15 @@ class MmprojModel(ModelBase):
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
if preprocessor_config_path.is_file():
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
cfg = json.load(f)
# move media_proc_cfg to root level for compat
if "media_proc_cfg" in cfg:
cfg = {
**cfg,
**cfg["media_proc_cfg"],
}
# merge configs
self.preprocessor_config = {**self.preprocessor_config, **cfg}
# prefer processor_config.json if possible
processor_config_path = self.dir_model / "processor_config.json"
@@ -1919,10 +1957,10 @@ class MmprojModel(ModelBase):
self.image_size = self.find_vparam(["image_size"])
self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -2700,8 +2738,6 @@ class AfmoeModel(LlamaModel):
super().set_gguf_parameters()
# MoE parameters
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
@@ -2723,7 +2759,7 @@ class AfmoeModel(LlamaModel):
# Handle expert weights - they're already merged in the HF format
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -3700,6 +3736,13 @@ class Ernie4_5Model(TextModel):
def set_vocab(self):
self._set_vocab_sentencepiece()
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)
if "add_prefix_space" in tokenizer_config_json:
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -3709,6 +3752,10 @@ class Ernie4_5Model(TextModel):
if (head_dim := self.hparams.get("head_dim")) is None:
head_dim = self.hparams["hidden_size"] // num_heads
if "mlp_AR" in name or "vision_model" in name:
# skip vision model and projector tensors
return
if "ernie." in name:
name = name.replace("ernie.", "model.")
# split the qkv weights
@@ -3818,6 +3865,48 @@ class Ernie4_5MoeModel(Ernie4_5Model):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("PaddleOCRVLForConditionalGeneration")
class PaddleOCRModel(Ernie4_5Model):
model_arch = gguf.MODEL_ARCH.PADDLEOCR
@ModelBase.register("PaddleOCRVisionModel")
class PaddleOCRVisionModel(MmprojModel):
# PaddleOCR-VL uses a modified version of Siglip
min_pixels: int = 0
max_pixels: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_pixels = self.preprocessor_config["min_pixels"]
self.max_pixels = self.preprocessor_config["max_pixels"]
self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels))
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR)
self.gguf_writer.add_vision_max_pixels(self.max_pixels)
self.gguf_writer.add_vision_min_pixels(self.min_pixels)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
name = name.replace("visual.", "model.")
if "vision_model" in name or "mlp_AR" in name:
if "packing_position_embedding" in name:
return # unused
elif "vision_model.head" in name:
# we don't yet support image embeddings for this model
return
else:
yield from super().modify_tensors(data_torch, name, bid)
return # skip other tensors
@ModelBase.register(
"Qwen2VLModel",
"Qwen2VLForConditionalGeneration",
@@ -4048,6 +4137,87 @@ class InternVisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register(
"NemotronH_Nano_VL_V2",
"RADIOModel",
)
class NemotronNanoV2VLModel(MmprojModel):
# ViT-Huge architecture parameters for RADIO v2.5-h
_vit_hidden_size = 1280
_vit_intermediate_size = 5120
_vit_num_layers = 32
_vit_num_heads = 16
def get_vision_config(self) -> dict[str, Any] | None:
# RADIO config doesn't have standard ViT parameters, so they need to be constructed manually
vision_config = self.global_config.get("vision_config")
if vision_config is None:
return None
# Add ViT-H parameters
vision_config = {
**vision_config,
"hidden_size": self._vit_hidden_size,
"intermediate_size": self._vit_intermediate_size,
"num_hidden_layers": self._vit_num_layers,
"num_attention_heads": self._vit_num_heads,
"image_size": self.global_config.get("force_image_size", 512),
}
return vision_config
def set_gguf_parameters(self):
if "image_mean" not in self.preprocessor_config:
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
if "image_std" not in self.preprocessor_config:
self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225]
super().set_gguf_parameters()
hparams = self.global_config
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL)
self.gguf_writer.add_vision_attention_layernorm_eps(1e-6)
self.gguf_writer.add_vision_use_gelu(True)
downsample_ratio = hparams.get("downsample_ratio", 0.5)
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name or "pos_embed" in new_name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "input_conditioner" in name:
return
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
if "patch_generator.pos_embed" in name:
if not name.endswith(".weight"):
name += ".weight"
# Downsample position embeddings for fixed 512x512 image size
import torch.nn.functional as F
n_embd = self.hparams["hidden_size"]
image_size = self.global_config.get("force_image_size", 512)
patch_size = self.hparams["patch_size"]
target_patches_per_side = image_size // patch_size # 32
max_patches_per_side = int((data_torch.shape[1]) ** 0.5) # 128
if target_patches_per_side != max_patches_per_side:
# Reshape to grid, interpolate, flatten back
data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd)
data_torch = data_torch.permute(0, 3, 1, 2).float() # [1, n_embd, 128, 128]
data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side),
mode='bilinear', align_corners=True)
data_torch = data_torch.permute(0, 2, 3, 1) # [1, 32, 32, n_embd]
data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd)
# Reshape linear patch embedding to conv2d format for ggml_conv_2d
# From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size]
if "patch_generator.embedder" in name:
patch_size = self.hparams["patch_size"]
n_embd = self.hparams["hidden_size"]
data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size)
if name.startswith("vision_model.radio_model.model.") or name.startswith("mlp1."):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("WavTokenizerDec")
class WavTokenizerDecModel(TextModel):
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
@@ -4090,8 +4260,6 @@ class Qwen2MoeModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -4136,7 +4304,7 @@ class Qwen2MoeModel(TextModel):
return
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -4475,7 +4643,7 @@ class Qwen3VLVisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
class Glm4VVisionModel(Qwen3VLVisionModel):
def set_gguf_parameters(self):
MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters
@@ -4887,13 +5055,13 @@ class PhiMoeModel(Phi3MiniModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams["num_local_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -5305,7 +5473,7 @@ class KimiLinearModel(TextModel):
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False)
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -5900,12 +6068,13 @@ class NomicBertModel(BertModel):
if "mlp.experts.bias" in name:
return # Explicitly return.
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
if "mlp.experts.mlp.w1" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
name += ".weight"
if "mlp.experts.mlp.w2" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.transpose(1, 2)
name += ".weight"
@@ -5915,7 +6084,6 @@ class NomicBertModel(BertModel):
super().set_gguf_parameters()
if self.is_moe:
self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
def _is_tokenizer_xlmroberta(self) -> bool:
@@ -7029,6 +7197,8 @@ class Mamba2Model(TextModel):
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
if "llm_config" in hparams:
hparams["text_config"] = hparams["llm_config"]
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
@@ -7150,8 +7320,8 @@ class JambaModel(TextModel):
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
self.gguf_writer.add_file_type(self.ftype)
_experts: list[dict[str, Tensor]] | None = None
@@ -7169,7 +7339,7 @@ class JambaModel(TextModel):
# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
@@ -7255,6 +7425,17 @@ class Cohere2Model(TextModel):
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Cohere2 runtime in llama.cpp expects no bias tensors;
# the actual weight only contains 0-value tensors as bias, we can skip them
if name.endswith(".bias"):
if torch.any(data_torch != 0):
raise ValueError(f"Bias tensor {name!r} is not zero.")
logger.debug(f"Skipping bias tensor {name!r} for Cohere2 conversion.")
return
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("OlmoForCausalLM")
@ModelBase.register("OLMoForCausalLM")
@@ -7317,8 +7498,6 @@ class OlmoeModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
_experts: list[dict[str, Tensor]] | None = None
@@ -7326,7 +7505,7 @@ class OlmoeModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -7695,12 +7874,16 @@ class DeepseekModel(TextModel):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# TODO @ngxson : remove this when we support MTP for deepseek models
skip_mtp = True
def set_vocab(self):
try:
self._set_vocab_gpt2()
@@ -7813,8 +7996,8 @@ class DeepseekV2Model(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
# skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
return
if name.startswith("siglip2.") or name.startswith("merger."):
return
@@ -7832,10 +8015,11 @@ class DeepseekV2Model(TextModel):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# skip Multi-Token Prediction (MTP) layers
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return
if self.skip_mtp:
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return
# process the experts separately
if name.find("mlp.experts") != -1:
@@ -7902,10 +8086,6 @@ class MiniMaxM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MINIMAXM2
_experts_cache: dict[int, dict[str, Tensor]] = {}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["num_experts"] = self.hparams["num_local_experts"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -7918,7 +8098,7 @@ class MiniMaxM2Model(TextModel):
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
@@ -8512,6 +8692,17 @@ class T5EncoderModel(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Jais2ForCausalLM")
class Jais2Model(TextModel):
model_arch = gguf.MODEL_ARCH.JAIS2
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])
self.gguf_writer.add_rope_dimension_count(head_dim)
@ModelBase.register("JAISLMHeadModel")
class JaisModel(TextModel):
model_arch = gguf.MODEL_ARCH.JAIS
@@ -8655,7 +8846,7 @@ class Glm4Model(TextModel):
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams["num_key_value_heads"]
n_embd = self.hparams["hidden_size"]
head_dim = n_embd // n_head
head_dim = self.hparams.get("head_dim", n_embd // n_head)
# because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
@@ -8664,6 +8855,27 @@ class Glm4Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GlmOcrForConditionalGeneration")
class GlmOCRModel(Glm4Model):
model_arch = gguf.MODEL_ARCH.GLM4
use_mrope = False
partial_rotary_factor = 0.5
# Note: GLM-OCR is the same as GLM4, but with an extra NextN/MTP prediction layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# GLM-OCR has num_hidden_layers + 1 actual layers (including NextN layer)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_gguf_parameters(self):
super().set_gguf_parameters()
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
@@ -8675,24 +8887,7 @@ class Glm4MoeModel(TextModel):
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -8792,26 +8987,38 @@ class Glm4MoeModel(TextModel):
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# copied from Glm4MoeModel
def set_vocab(self):
from transformers import AutoTokenizer
return self._set_vocab_glm()
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
@ModelBase.register("GlmMoeDsaForCausalLM")
class GlmMoeDsaModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.GLM_DSA
skip_mtp = False
special_vocab.add_to_gguf(self.gguf_writer)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
rope_dim = self.hparams["qk_rope_head_dim"]
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
# DSA indexer parameters
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
@@ -9128,7 +9335,6 @@ class ExaoneMoEModel(Exaone4Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
moe_intermediate_size = self.hparams["moe_intermediate_size"]
num_shared_experts = self.hparams["num_shared_experts"]
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -9169,7 +9375,7 @@ class ExaoneMoEModel(Exaone4Model):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -9320,7 +9526,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
# case, the model architecture needs to be updated to a standard
# "granite" or "granitemoe" model
if not self._ssm_layers:
has_experts = self.find_hparam(["num_experts_per_tok"], optional=True)
has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)
new_arch = (
gguf.MODEL_ARCH.GRANITE_MOE
if has_experts else
@@ -9516,6 +9722,14 @@ class NemotronHModel(GraniteHybridModel):
self.gguf_writer.add_add_bos_token(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision model and projector tensors for VLM models (handled by mmproj) (e.g., Nemotron Nano 12B v2 VL)
if name.startswith(("vision_model.", "mlp1.")):
return
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
if name.startswith("language_model."):
name = name[len("language_model."):]
if self.is_moe and bid is not None:
if name.endswith("mixer.gate.e_score_correction_bias"):
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -9610,7 +9824,6 @@ class BailingMoeModel(TextModel):
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_weights_scale(1.0)
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
@@ -9644,7 +9857,7 @@ class BailingMoeModel(TextModel):
yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid)
return
elif name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -9715,7 +9928,6 @@ class BailingMoeV2Model(TextModel):
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
@@ -9726,7 +9938,7 @@ class BailingMoeV2Model(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "mlp.experts" in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -9772,8 +9984,6 @@ class GroveMoeModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -9794,7 +10004,7 @@ class GroveMoeModel(TextModel):
# process the experts separately
if name.find("chunk_experts") != -1:
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2 # see add_experts_per_group
assert bid is not None
if self._chunk_experts is None:
@@ -9821,7 +10031,7 @@ class GroveMoeModel(TextModel):
else:
return
elif name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10214,7 +10424,6 @@ class HunYuanMoEModel(TextModel):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
moe_intermediate_size = hparams["moe_intermediate_size"]
@@ -10257,7 +10466,7 @@ class HunYuanMoEModel(TextModel):
return
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10299,16 +10508,9 @@ class LLaDAMoEModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
# number of experts used per token (top-k)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
self.gguf_writer.add_mask_token_id(156895)
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_diffusion_shift_logits(False)
@@ -10319,7 +10521,7 @@ class LLaDAMoEModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10594,7 +10796,7 @@ class LFM2Model(TextModel):
def set_gguf_parameters(self):
# set num_key_value_heads only for attention layers
self.hparams["num_key_value_heads"] = [
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
self.hparams["num_key_value_heads"] if layer_type != "conv" else 0
for layer_type in self.hparams["layer_types"]
]
@@ -10656,7 +10858,6 @@ class LFM2MoeModel(TextModel):
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
@@ -10677,7 +10878,7 @@ class LFM2MoeModel(TextModel):
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
@@ -10781,15 +10982,37 @@ class LFM2AudioModel(ConformerAudioModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm25AudioTokenizer")
class LFM25AudioTokenizer(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
def set_vocab(self):
self._set_vocab_none()
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_embedding_length_out(self.hparams["output_size"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "istft.window" or name.startswith("emb.emb"):
return
if name.startswith("lin"):
name = name.replace("lin", "dense_2_out")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -10814,7 +11037,7 @@ class SmallThinkerModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
n_experts = self.hparams.get("moe_num_primary_experts") or self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10872,13 +11095,17 @@ class ModernBertModel(BertModel):
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# these layers act as MLM head, so we don't need them
if name.startswith("decoder."):
return
if name.startswith("model."):
name = name[6:]
if self.cls_out_labels:
# For BertForSequenceClassification (direct projection layer)
if name == "classifier.weight":
name = "classifier.out_proj.weight"
if name == "classifier.bias":
name = "classifier.out_proj.bias"
yield from super().modify_tensors(data_torch, name, bid)
@@ -11176,6 +11403,103 @@ class KimiVLModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiK25ForConditionalGeneration")
class KimiK25Model(MmprojModel):
"""Kimi-K2.5 with MoonViT3d vision encoder"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
self.patch_size = self.hparams_vision.get("patch_size", 14)
# Set image_size for compatibility with base class
# Use position embedding dimensions as image_size reference
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
def set_gguf_parameters(self):
# Base class MmprojModel.set_gguf_parameters() already writes:
# - vision_block_count, vision_head_count, vision_embedding_length
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
# Position embedding parameters (for interpolation)
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
# Projector parameters
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
# Image size limits
# Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet)
in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384)
min_patches = 8 # reasonable minimum
pixels_per_patch = self.patch_size ** 2
self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch)
self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch)
@staticmethod
def permute(weights: Tensor, n_head: int) -> Tensor:
out_dim, in_dim = weights.shape
head_dim = out_dim // n_head
w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
w = w.permute(0, 2, 1, 3, 4)
return w.reshape(out_dim, in_dim)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision and projector tensors
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
if not is_vision:
return
assert self.hparams_vision is not None
n_head = self.hparams_vision.get("num_attention_heads", 16)
# Permute Q/K weights/biases from interleaved to split RoPE format
# This allows using build_rope_2d at runtime without post-permutation.
if "wqkv" in name:
out_dim = data_torch.shape[0]
qkv_dim = out_dim // 3
head_dim = qkv_dim // n_head
if "weight" in name:
wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :]
wq = self.permute(wq, n_head)
wk = self.permute(wk, n_head)
data_torch = torch.cat([wq, wk, wv], dim=0)
elif "bias" in name:
bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:]
bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
data_torch = torch.cat([bq, bk, bv], dim=0)
# Temporal embeddings: (T, 1, C) → (T, C)
if "pos_emb.time_weight" in name:
T, _, C = data_torch.shape
data_torch = data_torch.reshape(T, C)
# PatchMergerMLP tensor name mapping
# proj.0.weight → proj.linear_1.weight
# proj.2.weight → proj.linear_2.weight
if "mm_projector.proj.0." in name:
name = name.replace(".proj.0.", ".proj.linear_1.")
elif "mm_projector.proj.2." in name:
name = name.replace(".proj.2.", ".proj.linear_2.")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):
+6 -2
View File
@@ -99,6 +99,7 @@ models = [
{"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
{"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", },
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
@@ -113,6 +114,7 @@ models = [
{"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
{"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
{"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", },
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
@@ -148,7 +150,9 @@ models = [
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@@ -158,6 +162,7 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
@@ -171,7 +176,6 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]
+1 -1
View File
@@ -246,7 +246,7 @@ cmake --build build --config release
1. **Retrieve and prepare model**
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
**Notes**:
+2 -2
View File
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
+1 -1
View File
@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
Preset CMake variables:
+3 -3
View File
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ✅ | ❓ |
| BF16 | 🚫 | ✅ | ❓ |
| BF16 | | ✅ | ❓ |
| Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | ❓ | ❓ |
| MXFP4 | | ❓ | ❓ |
| Q5_0 | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ❓ | ❓ |
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
+4 -4
View File
@@ -31,7 +31,7 @@ Legend:
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -96,13 +96,13 @@ Legend:
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
+24 -16
View File
@@ -8760,22 +8760,14 @@
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=1","support","0","no","WebGPU"
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=32","support","0","no","WebGPU"
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","0","no","WebGPU"
"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
"WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
"WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
@@ -8786,22 +8778,14 @@
"WebGPU: WebGPU","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
"WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
"WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
@@ -18901,3 +18885,27 @@
"WebGPU: WebGPU","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQRT","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","COS","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQR","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SQRT","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SIN","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","COS","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
Can't render this file because it is too large.
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
config = config.text_config
multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
def print_if_exists(label, obj, attr, default="N/A"):
val = getattr(obj, attr) if hasattr(obj, attr) else default
print(f"{label}", val)
print_if_exists("Vocab size: ", config, "vocab_size")
print_if_exists("Hidden size: ", config, "hidden_size")
print_if_exists("Number of layers: ", config, "num_hidden_layers")
print_if_exists("BOS token id: ", config, "bos_token_id")
print_if_exists("EOS token id: ", config, "eos_token_id")
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name:
@@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False):
print(tensor_name)
def print_tensor_info(model_path: Path, tensor_name: str):
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
@@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str):
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
if num_values is not None:
tensor = f.get_tensor(tensor_name)
print(f"Dtype: {tensor.dtype}")
flat = tensor.flatten()
n = min(num_values, flat.numel())
print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
@@ -127,6 +133,15 @@ def main():
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
parser.add_argument(
"-n", "--num-values",
nargs="?",
const=10,
default=None,
type=int,
metavar="N",
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
)
args = parser.parse_args()
@@ -152,7 +167,7 @@ def main():
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name)
print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":
+1 -1
View File
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_PATCH 7)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
+1
View File
@@ -752,6 +752,7 @@ extern "C" {
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_view (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
+4 -9
View File
@@ -17,11 +17,6 @@
//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
#define AT_PRINTF(...)
static bool ggml_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
// ops that return true for this function must not use restrict pointers for their backend implementations
bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
GGML_ASSERT(buffer_id >= 0);
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
hn->allocated = true;
assert(hn->addr.offset == 0);
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
if (p_hn->n_children == 1 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
if (ggml_impl_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
// GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
// control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
// itself is never used and should not be considered a dependency
if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
struct ggml_tensor * view_src = node->view_src;
ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
}
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
if (ggml_impl_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
view_src_hn->n_views -= 1;
+9 -7
View File
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# Disable LTO for the feature detection code to prevent cross-module optimization
# from inlining architecture-specific instructions into the score function.
# Without this, LTO can cause SIGILL when loading backends on older CPUs
# (e.g., loading power10 backend on power9 crashes before feature check runs).
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
endfunction()
@@ -569,27 +574,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
cmake_policy(SET CMP0135 NEW)
endif()
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
FetchContent_Declare(KleidiAI_Download
URL ${KLEIDIAI_DOWNLOAD_URL}
DOWNLOAD_EXTRACT_TIMESTAMP NEW
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
FetchContent_MakeAvailable(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
message(FATAL_ERROR "KleidiAI source downloaded failed.")
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
# Remove kleidiai target after fetching it
if (TARGET kleidiai)
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
endif()
list(APPEND GGML_CPU_SOURCES
ggml-cpu/kleidiai/kleidiai.cpp
ggml-cpu/kleidiai/kernels.cpp
-6
View File
@@ -171,15 +171,9 @@
#elif defined(__riscv)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+310
View File
@@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntb() * 8 == 256) {
constexpr int q8_k_blocklen = 4;
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
svbool_t pg = svptrue_pat_b32(SV_VL8);
svuint32_t idx = svld1(pg, idx_arr);
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
acc_f32_01 = svdup_n_f32(0);
acc_f32_23 = svdup_n_f32(0);
acc_f32_45 = svdup_n_f32(0);
acc_f32_67 = svdup_n_f32(0);
for (int b = 0; b < nb; b++) {
// bsums pairs belongs to the same q8_k subblock
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
const int16x8_t bsums[4]{
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int32_t bsums_arr32[4][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
int16x8_t v16 = bsums[q8_row];
// low 4
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
// high 4
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
}
svint32_t sb_acc_0 = svdup_n_s32(0);
svint32_t sb_acc_2 = svdup_n_s32(0);
svint32_t acc_00 = svdup_n_s32(0);
svint32_t acc_11 = svdup_n_s32(0);
svint32_t acc_22 = svdup_n_s32(0);
svint32_t acc_33 = svdup_n_s32(0);
svint32_t acc_44 = svdup_n_s32(0);
svint32_t acc_55 = svdup_n_s32(0);
svint32_t acc_66 = svdup_n_s32(0);
svint32_t acc_77 = svdup_n_s32(0);
svint32_t bias_acc_00 = svdup_n_s32(0);
svint32_t bias_acc_22 = svdup_n_s32(0);
svint32_t bias_acc_44 = svdup_n_s32(0);
svint32_t bias_acc_66 = svdup_n_s32(0);
for (int sb = 0; sb < QK_K / 64; sb++) {
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
svint32_t q4sb_mins_0, q4sb_mins_1;
{
// 2-superblock I am working on
const int offset = sb * 24 + 0 * 12;
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
const int offset1 = sb * 24 + 12;
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
constexpr uint8_t scales_size = 12;
uint32_t sm[3];
memcpy(sm, scales_in, scales_size);
uint32_t sm1[3];
memcpy(sm1, scales_in1, scales_size);
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
/* reinterpret u32 → u8 */
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
/* widen u8 → u16->u32 (lower half only) */
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
uint32_t scales_u32_0 = sm[0] & kmask1;
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
uint32_t scales_u32_2 = sm1[0] & kmask1;
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
svuint32_t S01 = svdup_n_u32(scales_u32_0);
svuint32_t S23 = svdup_n_u32(scales_u32_1);
svuint32_t R01 = svdup_n_u32(scales_u32_2);
svuint32_t R23 = svdup_n_u32(scales_u32_3);
svint8_t S01_b = svreinterpret_s8_u32(S01);
svint8_t S23_b = svreinterpret_s8_u32(S23);
svint8_t R01_b = svreinterpret_s8_u32(R01);
svint8_t R23_b = svreinterpret_s8_u32(R23);
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
}
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
// Load 32-byte per row pair, 1 subblock each time
// predicate for activating higher lanes for 16 int8 elements
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
// predicate for activating lower lanes for 16 int8 elements
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
sb_acc_0 = svdup_n_s32(0);
sb_acc_2 = svdup_n_s32(0);
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
if(cp == 0) {
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
}
if(cp == 1) {
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
}
if(cp == 2) {
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
}
if(cp == 3) {
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
}
}
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
} // for sb
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
// Broadcast q8 scalar
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
q8_d = svdup_f32(q8_ptr[b].d[1]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
q8_d = svdup_f32(q8_ptr[b].d[2]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
q8_d = svdup_f32(q8_ptr[b].d[3]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
} // for b
// With the previous reorder, the tile is already in the correct memory layout.
// Predicate for exactly 4 lanes
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
if (i == 0 && j == 0) {
// acc_f32_0 → lower half of acc_f32_01
svst1_f32(pg4, s + offset, acc_f32_01);
} else if (i == 0 && j == 1) {
// acc_f32_1 → upper half of acc_f32_01
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
} else if (i == 1 && j == 0) {
// acc_f32_2
svst1_f32(pg4, s + offset, acc_f32_23);
} else if (i == 1 && j == 1) {
// acc_f32_3
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
} else if (i == 2 && j == 0) {
// acc_f32_4
svst1_f32(pg4, s + offset, acc_f32_45);
} else if (i == 2 && j == 1) {
// acc_f32_5
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
} else if (i == 3 && j == 0) {
// acc_f32_6
svst1_f32(pg4, s + offset, acc_f32_67);
} else if (i == 3 && j == 1) {
// acc_f32_7
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
}
}
}
} // for x
} // for y
return;
}
#endif // SVE compile-time end
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
+770
View File
@@ -1954,3 +1954,773 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}
static const uint8_t sign_gather_indices_arr[64] = {
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
};
static const uint8_t sign_bit_masks_arr[64] = {
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128
};
static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
const block_iq2_s * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
// --- Pre-load Constants ---
uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1};
vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8);
uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5};
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8);
// Constants for sign extraction
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * GGML_RESTRICT qs = x[i].qs;
const uint8_t * GGML_RESTRICT qh = x[i].qh;
const uint8_t * GGML_RESTRICT scales = x[i].scales;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const uint8_t * signs_ptr = qs + 32;
float sum_block = 0.0f;
for (int ib = 0; ib < 4; ++ib) {
// Combine low + high bits
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8);
qs += 8;
uint16_t qh_val;
memcpy(&qh_val, qh, 2);
qh += 2;
vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2);
vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2);
vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16);
vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8);
v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8);
// Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000
v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8);
vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8);
// Multiply by 8 to get byte offset, instead of element offset
v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8);
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8);
// Lookup Grid using Byte Offsets
vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8);
vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals);
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8);
// Load signs and generate sign mask
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8);
signs_ptr += 8;
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
q8 += 64;
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64);
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16));
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16));
int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16));
int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16));
uint8_t sc0 = scales[0];
uint8_t sc1 = scales[1];
scales += 2;
sum_block += s0 * (2 * (sc0 & 0xF) + 1);
sum_block += s1 * (2 * (sc0 >> 4) + 1);
sum_block += s2 * (2 * (sc1 & 0xF) + 1);
sum_block += s3 * (2 * (sc1 >> 4) + 1);
}
sumf += sum_block * combined_scale;
}
*s = 0.125f * sumf;
}
static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
const block_iq2_s * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
// Pre-load Constants
vuint8m2_t v_ids = __riscv_vid_v_u8m2(32);
vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32);
vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32);
vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32);
vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32);
uint16_t shift_qh_arr[4] = {11, 9, 7, 5};
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4);
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * GGML_RESTRICT qs = x[i].qs;
const uint8_t * GGML_RESTRICT qh = x[i].qh;
const uint8_t * GGML_RESTRICT scales = x[i].scales;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const uint8_t * signs_ptr = qs + 32;
float sum_block = 0.0f;
for (int ib = 0; ib < 8; ++ib) {
// Load Low Bits [4 bytes]
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4);
qs += 4;
// Load 1 byte. It contains bits for 4 mini-blocks.
uint8_t qh_val = *qh++;
// Combine Low + High bits of 10bit indices
vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4);
vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4);
vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4);
v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4);
vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4);
vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4);
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4);
// Lookup Grid
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4)));
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4);
signs_ptr += 4;
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32);
// generating sign mask
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32);
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32);
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32);
q8 += 32;
// apply signs
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32);
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32);
// Reduction
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
// Reduce 0-15 (First Half)
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
__riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16));
// Reduce 16-31 (Second Half)
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
__riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16));
// Apply sub Scales
uint8_t sc = *scales++;
sum_block += s0 * (2 * (sc & 0xF) + 1);
sum_block += s1 * (2 * (sc >> 4) + 1);
}
sumf += sum_block * combined_scale;
}
*s = 0.125f * sumf;
}
void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 128:
ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
break;
case 256:
ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq3_s * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
const uint64_t * grid64 = (const uint64_t *)iq3s_grid;
// --- Pre-load Constants ---
const uint16_t qh_bit_shifts_arr[16] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
};
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
const float combined_scale = d * y[i].d;
const uint8_t * GGML_RESTRICT qs = x[i].qs;
const uint8_t * GGML_RESTRICT qh = x[i].qh;
const uint8_t * GGML_RESTRICT scales = x[i].scales;
const uint8_t * GGML_RESTRICT signs = x[i].signs;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
float sum_block = 0.0f;
// Loop: Process 64 weights (16 mini-blocks of 4) per iteration
for (int ib = 0; ib < 4; ++ib) {
vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
qs += 16;
uint16_t qh_val;
memcpy(&qh_val, qh, 2);
qh += 2;
vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
// Extract bits: (qh >> i) & 1
v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);
vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);
// Grid value is 4xuint8
vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);
vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
signs += 8;
// Generate sign mask
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
q8 += 64;
// Apply Signs
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);
// Reduction
vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));
// Apply sub-scales
uint8_t sc_byte = *scales++;
int sc_lo = (sc_byte & 0xF) * 2 + 1;
int sc_hi = (sc_byte >> 4) * 2 + 1;
sum_block += s_lo * sc_lo + s_hi * sc_hi;
}
sumf += sum_block * combined_scale;
}
*s = 0.125f * sumf;
}
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_tq1_0 * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
float sumf = 0.0f;
uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
for (int i = 0; i < nb; i++) {
// First loop.
vint32m4_t suml1;
{
const int vl = 32;
vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl);
vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl);
vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl);
vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl);
vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl);
vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl);
vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl);
vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl);
vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl);
vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl);
vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl);
vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl);
vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl);
vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl);
vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl);
vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl);
vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl);
vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl);
suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl);
}
// Second loop.
vint32m2_t suml2;
{
const int vl = 16;
vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl);
vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl);
vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);
vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);
vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);
vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);
vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl);
vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl);
vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl);
vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl);
vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl);
vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);
vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);
vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);
vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);
vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl);
vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl);
suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl);
}
// Third loop.
vint32m2_t suml3;
{
const int vl = 16;
uint32_t qh;
memcpy(&qh, &x[i].qh[0], 4);
// Prevent fusion with vmv.
__asm__ __volatile__("" : "+r"(qh));
vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4));
vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl);
vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl);
vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl);
vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl);
}
vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16);
sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16);
sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16);
vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);
sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
}
*s = sumf;
}
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_tq2_0 * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
int32_t sumi = 0;
for (size_t j = 0; j < sizeof(x[0].qs); j += 32) {
const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32];
const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32];
const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32];
const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32];
const uint8_t* px = &x[i].qs[j];
size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32);
vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2);
size_t vl = __riscv_vsetvl_e8m1(32);
vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl);
vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl);
vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl);
vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl);
vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl);
// l=0 (bits 1:0)
vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl);
vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl);
// l=1 (bits 3:2)
vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl);
vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl);
// l=2 (bits 5:4)
vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl);
vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl);
// l=3 (bits 7:6)
vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros
vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl);
// 4. Multiply and accumulate
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl);
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl);
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl);
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl);
vlmax_16m2 = __riscv_vsetvl_e16m2(32);
vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1);
vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2);
sumi += __riscv_vmv_x_s_i32m1_i32(vred32);
}
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
sumf += (float)sumi * d;
}
*s = sumf;
}
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq1_s * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
// Load qh once for the entire superblock.
vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8);
// Calculate ls.
vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8);
temp = __riscv_vand_vx_u16mf2(temp, 7, 8);
vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8));
ls = __riscv_vadd_vx_i32m1(ls, 1, 8);
// Calculate delta.
vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8);
vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8);
vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8);
vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8);
// Load qs.
vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32);
// Prepare the indices.
const uint64_t shift = 0x0009000600030000;
vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8));
vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2(
__riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32));
vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh));
vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32);
qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32);
qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32);
qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32);
qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32);
vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32);
// Final lsums.
int32_t lsums_s[8];
vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1);
// Sub-blocks 1-4
{
vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0);
vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16));
vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128);
vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128);
lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32));
lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32));
lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32));
lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32));
}
__asm__ __volatile__("" ::: "memory");
// Sub-blocks 5-8
{
vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1);
vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16));
vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128);
vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128);
lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32));
lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32));
lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32));
lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32));
}
__asm__ __volatile__("" ::: "memory");
vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8);
// Calculate the bsums.
vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16);
const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0));
const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8));
const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8));
const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8);
// Accumulation.
vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8);
vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8);
// Update sumf.
int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
}
*s = sumf;
}
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq1_m * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
iq1m_scale_t scale;
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
const int8_t * q8 = y[i].qs;
const uint8_t * qs = x[i].qs;
const uint8_t * qh = x[i].qh;
const uint16_t * sc = (const uint16_t *)x[i].scales;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
// Accumulators.
vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16);
vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16);
// We process 4 sub-blocks together.
for (int ib = 0; ib < QK_K/128; ib++) {
// Load qh for 4 sub-blocks.
const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8);
const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8);
const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8);
const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1(
__riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16);
qh += 8;
// Prepare grid indices.
const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16);
const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8));
vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16);
index = __riscv_vsll_vx_u16m1(index, 3, 16);
qs += 16;
// Load the grid.
const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4(
__riscv_vluxei16_v_u64m4(iq1s_grid, index, 16)));
// Prepare the deltas.
const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16(
__riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16);
const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16);
const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16);
const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4(
__riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16));
// Load q8 for sub-blocks.
const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
q8 += 128;
// Calculate the lsums.
const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128);
const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128);
// Prepare the scales.
const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1;
const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1;
const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1;
const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1;
const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1;
const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1;
const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1;
const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1;
sc += 2;
// Accumulate in acc0 and acc1 for each sub-block.
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16);
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16);
//
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16);
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16);
//
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16);
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16);
//
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16);
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16);
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16);
}
// Reduce and accumulate in `sumf`.
vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1);
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16));
int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16));
sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2);
}
*s = sumf;
}
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
+2 -2
View File
@@ -6,8 +6,8 @@
#include "ggml-impl.h"
#include "simd-mappings.h"
#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#define GGML_FA_TILE_Q 64
#define GGML_FA_TILE_KV 64
#ifdef __cplusplus
+12 -4
View File
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
const int64_t DV = node->src[2]->ne[0];
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/*.use_ref =*/ cplan->use_ref,
};
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
#else
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
#else
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif
ggml_barrier(state->threadpool);
-333
View File
@@ -1,333 +0,0 @@
#pragma once
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
template <typename TA>
class tinyBLAS_Q0_PPC {
public:
tinyBLAS_Q0_PPC(int64_t k,
const TA *A, int64_t lda,
const block_q8_0 *B, int64_t ldb,
float *C, int64_t ldc,
int ith, int nth);
void matmul(int64_t m, int64_t n);
void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
vec_t A_pack[mc*kc*2];
vec_t B_pack[nc*kc*2];
int comparray[mc*kc];
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
if constexpr(is_Ablock_q4) {
packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
} else {
packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
}
packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
}
}
}
private:
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
}
}
}
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
*c_ptr += *((float*)&fin_res[idx+I]+J);
}
}
}
template<typename ArrayType>
inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
vector signed int vec_C[4];
vector float CA[4] = {0};
vector float res[4] = {0};
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int i = 0; i < 4; i++) {
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
}
}
inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
const vector signed char v8 = vec_splats((signed char)0x8);
vector signed int vsum = {0};
vector signed int vsum2 = {0};
c[0] = vec_and(c[1], lowMask);
c[1] = vec_sr(c[1], v4);
c[0] = vec_sub(c[0], v8);
c[1] = vec_sub(c[1], v8);
vsum = vec_sum4s(c[0], vsum);
vsum2 = vec_sum4s(c[1], vsum2);
vsum = vec_add(vsum, vsum2);
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
}
template <typename V1, typename V2>
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
V2 t1, t2, t3, t4, t5, t6, t7, t8;
vector unsigned char xor_vector;
uint8_t flip_vec = 0x80;
xor_vector = vec_splats(flip_vec);
t1 = vec_perm(s1, s2, swiz1);
t2 = vec_perm(s1, s2, swiz2);
t3 = vec_perm(s3, s4, swiz1);
t4 = vec_perm(s3, s4, swiz2);
t5 = vec_perm(t1, t3, swiz3);
t6 = vec_perm(t1, t3, swiz4);
t7 = vec_perm(t2, t4, swiz3);
t8 = vec_perm(t2, t4, swiz4);
if (flip == true) {
t5 = vec_xor(t5, xor_vector);
t6 = vec_xor(t6, xor_vector);
t7 = vec_xor(t7, xor_vector);
t8 = vec_xor(t8, xor_vector);
}
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset+16);
vec_xst(t7, 0, vecOffset+32);
vec_xst(t8, 0, vecOffset+48);
}
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii,jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii,jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii,jj);
} else {
assert(false && "RN/RM values not supported");
}
}
template<int size>
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
template<typename VA, typename VB>
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
void KERNEL_4x8(int64_t ii, int64_t jj);
void KERNEL_8x4(int64_t ii, int64_t jj);
void KERNEL_8x8(int64_t ii, int64_t jj);
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
template <int RM, int RN>
void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
for (int I = 0; I<8; I++) {
float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
*((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
}
}
}
inline void process_q8_elements(const int8_t *qs, int *ca) {
vector signed char c1 = vec_xl(0, qs);
vector signed char c2 = vec_xl(16, qs);
vector signed int vsum1 = {0};
vector signed int vsum2 = {0};
vsum1 = vec_sum4s(c1, vsum1);
vsum2 = vec_sum4s(c2, vsum2);
vector signed int vsum = vec_add(vsum1, vsum2);
*ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
}
template<typename VA, typename VB>
void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
int64_t i, j;
block_q8_0 *aoffset = NULL;
VA *vecOffset = NULL;
block_q8_0* aoffsets[8];
__vector_pair arr[8];
VB c[8][2] = {0};
VB c1[8] = {0}; VB c2[8] = {0};
aoffset = const_cast<block_q8_0*>(a);
vecOffset = vec;
j = (rows >> 3);
int index = 0;
if (j > 0) {
do {
for (int it = 0; it < 8; it++)
aoffsets[it] = aoffset + it*lda;
aoffset += 8 * lda;
for (int blk = 0; blk < kc; blk++) {
for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
if (comparray){
process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
}
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
vecOffset += 256;
}
j--;
index += 8*kc;
} while(j > 0);
}
}
void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
int64_t i, j;
TA *aoffset = NULL;
int8_t *vecOffset = NULL;
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
aoffset = const_cast<TA*>(a);
vecOffset = vec;
int index = 0;
j = (rows >> 3);
if (j > 0) {
do {
aoffset1 = aoffset;
aoffset2 = aoffset1 + lda;
aoffset3 = aoffset2 + lda;
aoffset4 = aoffset3 + lda;
aoffset5 = aoffset4 + lda;
aoffset6 = aoffset5 + lda;
aoffset7 = aoffset6 + lda;
aoffset8 = aoffset7 + lda;
aoffset += 8 * lda;
for (int blk = 0; blk < kc; blk++) {
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
process_q4_elements(c1, &comparray[index + 8*blk+0]);
process_q4_elements(c2, &comparray[index + 8*blk+1]);
process_q4_elements(c3, &comparray[index + 8*blk+2]);
process_q4_elements(c4, &comparray[index + 8*blk+3]);
process_q4_elements(c5, &comparray[index + 8*blk+4]);
process_q4_elements(c6, &comparray[index + 8*blk+5]);
process_q4_elements(c7, &comparray[index + 8*blk+6]);
process_q4_elements(c8, &comparray[index + 8*blk+7]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
vecOffset += 256;
}
j--;
index += 8*kc;
} while (j > 0);
}
}
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
acc_t acc[8];
for (int i = 0; i < mc ; i += 8) {
for (int j = 0; j < nc; j += 8) {
vector float fin_res[16] = {0};
vector float vs[16] = {0};
for (int64_t kk = 0; kk < kc; kk+=2) {
for (int x = 0; x < 8; x++) {
__builtin_mma_xxsetaccz(&acc[x]);
}
int A_block_idx = (i/8)*(16*kc) + kk*16;
int B_block_idx = (j/8)*(16*kc)+ kk*16;
vec_t *A_block = &vec_A[A_block_idx];
vec_t *B_block = &vec_B[B_block_idx];
for (int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
}
compute_scale(ii+i, jj+j, l+kk, vs);
int c_index = (i/8)*(8*kc)+ kk*8;
int* c_block = &comparray[c_index];
compute(&acc[0], 0, 0, c_block, vs, fin_res);
compute(&acc[1], 4, 4, c_block, vs, fin_res);
compute(&acc[2], 0, 8, c_block, vs, fin_res);
compute(&acc[3], 4, 12, c_block, vs, fin_res);
A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
A_block = &vec_A[A_block_idx];
B_block = &vec_B[B_block_idx];
for (int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
__builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
__builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
}
compute_scale(ii+i, jj+j, l+kk+1, vs);
c_index = (i/8)*(8*kc)+ (kk+1)*8;
c_block = &comparray[c_index];
compute(&acc[4], 0, 0, c_block, vs, fin_res);
compute(&acc[5], 4, 4, c_block, vs, fin_res);
compute(&acc[6], 0, 8, c_block, vs, fin_res);
compute(&acc[7], 4, 12, c_block, vs, fin_res);
}
if (l == 0) {
save_res(ii+i, jj+j, 0, fin_res);
save_res(ii+i+4, jj+j, 4, fin_res);
save_res(ii+i, jj+j+4, 8, fin_res);
save_res(ii+i+4, jj+j+4, 12, fin_res);
} else {
add_save_res(ii+i, jj+j, 0, fin_res);
add_save_res(ii+i+4, jj+j, 4, fin_res);
add_save_res(ii+i, jj+j+4, 8, fin_res);
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
}
}
}
}
const TA *const A;
const block_q8_0 *const B;
float *C;
const int64_t k;
int64_t kc;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
+507 -155
View File
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
#endif
#if defined(__MMA__)
#include "sgemm-ppc.h"
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
}
}
SAVE_ACC(&acc_0, ii, jj);
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
const int nth;
};
template <typename TA>
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
const TA *A, int64_t lda,
const block_q8_0 *B, int64_t ldb,
float *C, int64_t ldc,
int ith, int nth)
template <typename TA>
class tinyBLAS_Q0_PPC {
public:
tinyBLAS_Q0_PPC(int64_t k,
const TA * A, int64_t lda,
const block_q8_0 * B, int64_t ldb,
float * C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
kc = 64;
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
int mc = 64; int nc = 64;
if (n % 8 == 0 && n < nc) {
nc = n;
mc = 32 ;
kc = 32;
void matmul(int64_t m, int64_t n) {
const int64_t mc = 64;
const int64_t kc = 64;
int64_t nc = 64;
int64_t n_aligned = 0;
if (n % 64 == 0) {
n_aligned = n;
} else if (n == 4) {
n_aligned = 4;
} else if (n < 64) {
n_aligned = (n / 8) * 8;
} else {
n_aligned = (n / 64) * 64;
}
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
if (is_aligned) {
this->matmul_tiled_q0(m, n, mc, nc, kc);
if (n_aligned > 0) {
if (n_aligned % 64 == 0) nc = 64;
else if (n_aligned == n) nc = n;
else if (n_aligned % 32 == 0) nc = 32;
else if (n_aligned % 24 == 0) nc = 24;
else if (n_aligned % 16 == 0) nc = 16;
else nc = 8;
}
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
if (can_use_tiled) {
matmul_tiled(m, n_aligned, mc, nc, kc);
if (n > n_aligned) {
mnpack(0, m, n_aligned, n);
}
} else {
mnpack(0, m, 0, n);
}
}
template<typename TA>
template<int size>
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
private:
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
}
}
}
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
}
}
}
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
*c_ptr += *((float *)&vec_C[I] + J);
}
}
}
template<typename ArrayType>
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
vector signed int vec_C[4];
vector float CA[4] = {0};
vector float res[4] = {0};
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int i = 0; i < 4; i++) {
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
}
}
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
const vector signed char lowMask = vec_splats((signed char)0xF);
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
const vector signed char v8 = vec_splats((signed char)0x8);
vector signed int vsum = {0};
vector signed int vsum2 = {0};
c[0] = vec_and(c[1], lowMask);
c[1] = vec_sr(c[1], v4);
c[0] = vec_sub(c[0], v8);
c[1] = vec_sub(c[1], v8);
vsum = vec_sum4s(c[0], vsum);
vsum2 = vec_sum4s(c[1], vsum2);
vsum = vec_add(vsum, vsum2);
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
}
template <typename V1, typename V2>
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
V2 t1, t2, t3, t4, t5, t6, t7, t8;
vector unsigned char xor_vector;
uint8_t flip_vec = 0x80;
xor_vector = vec_splats(flip_vec);
t1 = vec_perm(s1, s2, swiz1);
t2 = vec_perm(s1, s2, swiz2);
t3 = vec_perm(s3, s4, swiz1);
t4 = vec_perm(s3, s4, swiz2);
t5 = vec_perm(t1, t3, swiz3);
t6 = vec_perm(t1, t3, swiz4);
t7 = vec_perm(t2, t4, swiz3);
t8 = vec_perm(t2, t4, swiz4);
if (flip == true) {
t5 = vec_xor(t5, xor_vector);
t6 = vec_xor(t6, xor_vector);
t7 = vec_xor(t7, xor_vector);
t8 = vec_xor(t8, xor_vector);
}
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 16);
vec_xst(t7, 0, vecOffset + 32);
vec_xst(t8, 0, vecOffset + 48);
}
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
const vector signed char lowMask = vec_splats((signed char)0x0F);
const vector signed char v8 = vec_splats((signed char)0x08);
const vector unsigned char v4 = vec_splats((unsigned char)4);
lo = vec_and(packed, lowMask);
hi = vec_sr(packed, v4);
lo = vec_sub(lo, v8);
hi = vec_sub(hi, v8);
}
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
vec_t t[8], s[8];
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
for (int i = 0; i < 4; i += 2) {
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
}
for (int i = 4; i < 8; i += 2) {
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
}
s[0] = vec_perm(t[0], t[2], swiz3);
s[1] = vec_perm(t[0], t[2], swiz4);
s[2] = vec_perm(t[1], t[3], swiz3);
s[3] = vec_perm(t[1], t[3], swiz4);
s[4] = vec_perm(t[4], t[6], swiz3);
s[5] = vec_perm(t[4], t[6], swiz4);
s[6] = vec_perm(t[5], t[7], swiz3);
s[7] = vec_perm(t[5], t[7], swiz4);
for (int i = 0; i < 8; ++i) {
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
}
}
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
vector signed short i16_hi = vec_unpackh(raw);
vector signed short i16_lo = vec_unpackl(raw);
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
}
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
unsigned char * vecOffset = vec;
for (int i = 0; i < rows; i += 8) {
const block_q4_0 * rows_base[8];
for (int r = 0; r < 8; r++) {
rows_base[r] = a + (i + r) * lda;
}
for (int blk = 0; blk < blocks; blk++) {
vector unsigned short hp_res[8][4];
for (int r = 0; r < 8; r++) {
const block_q4_0 * current_blk = rows_base[r] + blk;
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
vector signed char c1, c2;
unpack_q4_to_q8(v_qs, c1, c2);
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
}
for (int c = 0; c < 4; c++) {
vector unsigned char c_arr[8];
for (int r = 0; r < 8; r++) {
c_arr[r] = (vector unsigned char)hp_res[r][c];
}
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
vecOffset += 128;
}
}
}
}
template <int chunk_size>
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
unsigned char * vecOffset = vec;
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
for (int i = 0; i < rows; i += chunk_size) {
const block_q8_0 * rows_base[chunk_size];
for (int r = 0; r < chunk_size; r++) {
rows_base[r] = a + (i + r) * lda;
}
for (int blk = 0; blk < blocks; blk++) {
vector unsigned short hp_res[chunk_size][4];
for (int r = 0; r < chunk_size; r++) {
const block_q8_0 * b = rows_base[r] + blk;
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
vector signed char c[2];
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
__builtin_vsx_disassemble_pair(c, & pair);
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
}
for (int col = 0; col < 4; col++) {
if constexpr (chunk_size == 8) {
vec_t t[8];
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
vecOffset += 128;
} else {
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
vecOffset += 64;
}
}
}
}
}
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
if (rows == 4) {
pack_q8_block<4>(a, lda, rows, blocks, vec);
} else {
pack_q8_block<8>(a, lda, rows, blocks, vec);
}
}
template<int size>
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
int64_t i, j;
TA *aoffset = NULL;
int8_t *vecOffset = NULL;
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
TA * aoffset = NULL;
int8_t * vecOffset = NULL;
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
aoffset = const_cast<TA*>(a);
aoffset = const_cast<TA *>(a);
vecOffset = vec;
j = (rows >> 3);
if (j > 0) {
@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
process_q4_elements(c1, &comparray[0]);
process_q4_elements(c2, &comparray[1]);
process_q4_elements(c3, &comparray[2]);
process_q4_elements(c4, &comparray[3]);
process_q4_elements(c5, &comparray[4]);
process_q4_elements(c6, &comparray[5]);
process_q4_elements(c7, &comparray[6]);
process_q4_elements(c8, &comparray[7]);
process_q4_elements(c1, & comparray[0]);
process_q4_elements(c2, & comparray[1]);
process_q4_elements(c3, & comparray[2]);
process_q4_elements(c4, & comparray[3]);
process_q4_elements(c5, & comparray[4]);
process_q4_elements(c6, & comparray[5]);
process_q4_elements(c7, & comparray[6]);
process_q4_elements(c8, & comparray[7]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
process_q4_elements(c1, &comparray[0]);
process_q4_elements(c2, &comparray[1]);
process_q4_elements(c3, &comparray[2]);
process_q4_elements(c4, &comparray[3]);
process_q4_elements(c1, & comparray[0]);
process_q4_elements(c2, & comparray[1]);
process_q4_elements(c3, & comparray[2]);
process_q4_elements(c4, & comparray[3]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
break;
}
process_q4_elements(c1, &comparray[0]);
process_q4_elements(c2, &comparray[1]);
process_q4_elements(c3, &comparray[2]);
process_q4_elements(c4, &comparray[3]);
process_q4_elements(c1, & comparray[0]);
process_q4_elements(c2, & comparray[1]);
process_q4_elements(c3, & comparray[2]);
process_q4_elements(c4, & comparray[3]);
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
aoffset1 += lda;
aoffset2 += lda;
aoffset3 += lda;
@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
}
}
template<typename TA>
template<typename VA, typename VB>
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
int64_t i, j;
block_q8_0 *aoffset = NULL;
VA *vecOffset = NULL;
block_q8_0* aoffsets[8];
block_q8_0 * aoffset = NULL;
VA * vecOffset = NULL;
block_q8_0 * aoffsets[8];
__vector_pair arr[8];
VB c[8][2] = {0};
VB c1[8] = {0}; VB c2[8] = {0};
aoffset = const_cast<block_q8_0*>(a);
aoffset = const_cast<block_q8_0 *>(a);
vecOffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
aoffsets[0] = aoffset;
for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffsets[it] = aoffsets[it - 1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
for (int it = 0; it < 8; it++)
aoffsets[it] += lda;
vecOffset += 256;
@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
if (i > 0) {
do {
for (int it = 0; it < 4; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
c1[it] = c[it][0];
c2[it] = c[it][1];
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
for (int it = 0; it < 4; it++) {
aoffsets[it] += lda;
}
@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
if (rows & 3) {
aoffsets[0] = aoffset;
for (int it = 1; it < 3; it++ )
aoffsets[it] = aoffsets[it-1] + lda;
aoffsets[it] = aoffsets[it - 1] + lda;
i = (cols >> 3);
if (i > 0) {
do {
switch(rows) {
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
c1[2] = c[2][0]; c2[2] = c[2][1];
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
c1[1] = c[1][0]; c2[1] = c[1][1];
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
c1[0] = c[0][0]; c2[0] = c[0][1];
break;
}
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
for (int it = 0; it < 3; it++)
aoffsets[it] += lda;
vecOffset += 128;
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
}
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int m_rem = MIN(m - m0, 16);
int n_rem = MIN(n - n0, 16);
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
void KERNEL_4x8(int64_t ii, int64_t jj) {
vec_t vec_A[8], vec_B[16] = {0};
acc_t acc_0, acc_1;
std::array<int, 4> comparray {};
@@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC {
vector float vs[8] = {0};
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
for (int l = 0; l < k; l++) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(& acc_0);
__builtin_mma_xxsetaccz(& acc_1);
if (std::is_same_v<TA, block_q4_0>) {
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
for(int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
}
for (int I = 0; I<4; I++) {
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
}
}
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < 4; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
aoffset += lda;
}
}
compute(&acc_0, 0, 0, comparray, vs, fin_res);
compute(&acc_1, 0, 4, comparray, vs, fin_res);
compute(& acc_0, 0, 0, comparray, vs, fin_res);
compute(& acc_1, 0, 4, comparray, vs, fin_res);
}
save_res(ii, jj, 0, fin_res);
save_res(ii, jj+4, 4, fin_res);
save_res(ii, jj + 4, 4, fin_res);
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
void KERNEL_8x4(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[8] = {0};
acc_t acc_0, acc_1;
std::array<int, 8> comparray {};
@@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC {
vector float vs[8] = {0};
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
for (int l = 0; l < k; l++) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(& acc_0);
__builtin_mma_xxsetaccz(& acc_1);
if (std::is_same_v<TA, block_q4_0>) {
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
for(int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
}
for (int I = 0; I<8; I++) {
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
for (int I = 0; I < 8; I++) {
for (int J = 0; J < 4; J++) {
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
}
}
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < 8; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
aoffset += lda;
}
}
compute(&acc_0, 0, 0, comparray, vs, fin_res);
compute(&acc_1, 4, 4, comparray, vs, fin_res);
compute(& acc_0, 0, 0, comparray, vs, fin_res);
compute(& acc_1, 4, 4, comparray, vs, fin_res);
}
save_res(ii, jj, 0, fin_res);
save_res(ii+4, jj, 4, fin_res);
save_res(ii + 4, jj, 4, fin_res);
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
void KERNEL_8x8(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[16] = {0};
acc_t acc_0, acc_1, acc_2, acc_3;
acc_t acc_4, acc_5, acc_6, acc_7;
@@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC {
vector float vs[16] = {0};
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
for (int l = 0; l < k; l++) {
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
__builtin_mma_xxsetaccz(& acc_0);
__builtin_mma_xxsetaccz(& acc_1);
__builtin_mma_xxsetaccz(& acc_2);
__builtin_mma_xxsetaccz(& acc_3);
if (std::is_same_v<TA, block_q4_0>) {
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
for(int x = 0; x < 8; x++) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
}
for (int I = 0; I<8; I++) {
for (int J = 0; J<4; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
for (int I = 0; I < 8 ; I++) {
for (int J = 0; J < 4; J++) {
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
}
}
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < 8; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
aoffset += lda;
}
}
compute(&acc_0, 0, 0, comparray, vs, fin_res);
compute(&acc_1, 4, 4, comparray, vs, fin_res);
compute(&acc_2, 0, 8, comparray, vs, fin_res);
compute(&acc_3, 4, 12, comparray, vs, fin_res);
compute(& acc_0, 0, 0, comparray, vs, fin_res);
compute(& acc_1, 4, 4, comparray, vs, fin_res);
compute(& acc_2, 0, 8, comparray, vs, fin_res);
compute(& acc_3, 4, 12, comparray, vs, fin_res);
}
save_res(ii, jj, 0, fin_res);
save_res(ii+4, jj, 4, fin_res);
save_res(ii, jj+4, 8, fin_res);
save_res(ii+4, jj+4, 12, fin_res);
save_res(ii + 4, jj, 4, fin_res);
save_res(ii, jj + 4, 8, fin_res);
save_res(ii + 4, jj + 4, 12, fin_res);
}
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
acc_t acc[8];
for (int i = 0; i < mc ; i += 16) {
for (int j = 0; j < nc; j += 8) {
int A0_base = (i / 16) * (2 * 32 * kc);
int B0_base = (j / 8) * (32 * kc);
for (int x = 0; x < 8; x++) {
__builtin_mma_xxsetaccz(&acc[x]);
}
for (int64_t kk = 0; kk < kc; kk++) {
int A0_block_idx = A0_base + kk * 32;
int B0_block_idx = B0_base + kk * 32;
int A1_block_idx = A0_block_idx + 32 * kc;
int B1_block_idx = B0_block_idx + 32 * kc;
vec_t * A0_block = & vec_A[A0_block_idx];
vec_t * B0_block = & vec_B[B0_block_idx];
vec_t * A1_block = & vec_A[A1_block_idx];
for (int it = 0; it < 4; it++) {
for (int x = 0; x < 4; x++) {
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
}
}
}
if (l == 0) {
save_acc(& acc[0], ii + i, jj + j);
save_acc(& acc[1], ii + i, jj + j + 4);
save_acc(& acc[2], ii + i + 4, jj + j);
save_acc(& acc[3], ii + i + 4, jj + j + 4);
save_acc(& acc[4], ii + i + 8, jj + j);
save_acc(& acc[5], ii + i + 8, jj + j + 4);
save_acc(& acc[6], ii + i + 12, jj + j);
save_acc(& acc[7], ii + i + 12, jj + j + 4);
} else {
add_save_acc(& acc[0], ii + i, jj + j);
add_save_acc(& acc[1], ii + i, jj + j + 4);
add_save_acc(& acc[2], ii + i + 4, jj + j);
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
add_save_acc(& acc[4], ii + i + 8, jj + j);
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
add_save_acc(& acc[6], ii + i + 12, jj + j);
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
}
}
}
}
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
vec_t A_pack[mc * kc * 4];
vec_t B_pack[nc * kc * 4];
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
if constexpr(is_Ablock_q4) {
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
} else {
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
}
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
}
}
}
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC {
vector float fin_res[4] = {0};
vector float vs[4] = {0};
vector float CA[4] = {0};
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
for (int l = 0; l < k; l++) {
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(&acc_0);
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(& acc_0);
if (isAblock_q4) {
packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x+=4) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
for (int x = 0; x < 8; x += 4) {
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
}
for (int I = 0; I<RM; I++) {
for (int J = 0; J<RN; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
__builtin_mma_disassemble_acc(vec_C, & acc_0);
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < RM; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
}
}
template<typename TA>
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii,jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii,jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii,jj);
} else {
assert(false && "RN/RM values not supported");
}
}
template <int RM, int RN>
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
this->kernel<RM, RN>(ii, jj);
kernel<RM, RN>(ii, jj);
}
}
template class tinyBLAS_Q0_PPC<block_q4_0>;
template class tinyBLAS_Q0_PPC<block_q8_0>;
const TA * const A;
const block_q8_0 * const B;
float * C;
const int64_t k;
int64_t kc;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
class tinyBLAS_PPC {
public:
+173 -94
View File
@@ -3,6 +3,7 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"
@@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -8325,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
@@ -8360,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
@@ -8388,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
float * V32 = VKQ32 + Q_TILE_SZ * DV;
float * K_f32 = V32 + KV_TILE_SZ * DV;
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8412,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
{
float * Q_f32 = (float *)Q_q;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
}
}
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
// skip the tile entirely if all the masks are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
for (int tk = 0; tk < kv_tile; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
// Pad remaining mask entries with -inf
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
if (can_skip) {
@@ -8441,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
for (int tk = 0; tk < kv_tile; tk++) {
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
if (kv_type == GGML_TYPE_F16) {
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
}
} else {
const float * k_f32_src = (const float *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
}
}
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
// Set padded KQ entries to -inf so softmax gives them zero weight
if (kv_tile < KV_TILE_SZ) {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
}
@@ -8487,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}
// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
// V accumulation: VKQ32 += softmax(KQ) * V
// Pack V tile to contiguous F32, zero-padded
for (int tk = 0; tk < kv_tile; tk++) {
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
if (kv_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
} else {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) {
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
}
}
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
}
// sinks (apply only to valid rows in the tile)
@@ -8730,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool use_tiled = !use_ref &&
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD
use_tiled &= (DV % GGML_F32_EPR == 0);
#endif
int current_chunk = ith;
while (current_chunk < nchunk) {
+3 -2
View File
@@ -1916,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
// buffer large enough for the max interleave block size (8 bytes)
uint64_t elems;
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
}
// The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
+136
View File
@@ -0,0 +1,136 @@
#pragma once
// Computes C[M x N] += A[M x K] * B[K x N]
#include "simd-mappings.h"
// TODO: add support for sizeless vector types
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
// TODO: untested on avx512
// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
#elif defined(__AVX2__) || defined(__AVX__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
#else
static constexpr int GEMM_RM = 2;
static constexpr int GEMM_RN = 2;
#endif
template <int RM, int RN>
static inline void simd_gemm_ukernel(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
GGML_F32_VEC acc[RM][RN];
for (int64_t i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
}
}
for (int64_t kk = 0; kk < K; kk++) {
GGML_F32_VEC Bv[RN];
for (int r = 0; r < RN; r++) {
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
}
for (int64_t i = 0; i < RM; i++) {
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
}
}
}
for (int64_t i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
}
}
}
// C[M x N] += A[M x K] * B[K x N]
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
int64_t ii = 0;
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
}
for (; jj < N; jj++) {
for (int64_t i = 0; i < GEMM_RM; i++) {
float a = C[i * N + jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[i + kk] * B[kk * N + jj];
}
C[i * N + jj] = a;
}
}
A += GEMM_RM * K;
C += GEMM_RM * N;
}
// Tail rows: one at a time
for (; ii < M; ii++) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
}
for (; jj < N; jj++) {
float a = C[jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[kk] * B[kk * N + jj];
}
C[jj] = a;
}
A += K;
C += N;
}
}
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#else // scalar path
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
for (int64_t i = 0; i < M; i++) {
for (int64_t j = 0; j < N; j++) {
float sum = C[i * N + j];
for (int64_t kk = 0; kk < K; kk++) {
sum += A[i * K + kk] * B[kk * N + j];
}
C[i * N + j] = sum;
}
}
}
#endif // GGML_SIMD
+26
View File
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
float32x4_t tmp = x[0] + vec_reve(x[0]); \
res = tmp[0] + tmp[1]; \
}
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
{ \
float32x4_t v = vec_add(vec_add(s0, s1), \
vec_add(s2, s3)); \
v = vec_add(v, vec_sld(v, v, 8)); \
v = vec_add(v, vec_sld(v, v, 4)); \
res += (ggml_float)vec_extract(v, 0); \
}
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
// BF16 s390x
#define GGML_BF16_STEP 16
#define GGML_BF16_EPR 8
#define GGML_BF16x8 __vector unsigned short
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
#define GGML_BF16_VEC GGML_BF16x8
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_FMA_LO(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
#define GGML_BF16_FMA_HI(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
#elif defined(__riscv_v_intrinsic)
// compatible with vlen >= 128
+1 -1
View File
@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS
+1 -2
View File
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
#endif
#if defined(__POWER9_VECTOR__)
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
const int np = (n & ~(GGML_BF16_STEP - 1));
if (np > 0) {
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
+2 -15
View File
@@ -1149,8 +1149,7 @@ struct ggml_cuda_graph {
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
int number_consecutive_updates = 0;
bool warmup_complete = false;
std::vector<ggml_cuda_graph_node_properties> props;
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
@@ -1159,21 +1158,9 @@ struct ggml_cuda_graph {
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
std::vector<ggml_cuda_graph_node_properties> extra;
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
number_consecutive_updates++;
} else {
number_consecutive_updates = 0;
}
if (number_consecutive_updates >= 4) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
disable_due_to_too_many_updates = true;
}
}
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
}
#endif
};
+38 -24
View File
@@ -7,7 +7,8 @@
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
template <bool need_check>
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <typename src_t, typename dst_t>
+3 -1
View File
@@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
// On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
// However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
if constexpr (DV == 512) {
+26 -5
View File
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
#else
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
#else
const half * K_h_f16 = K_h;
const half * V_h_f16 = V_h;
half * KQ_f16 = KQ;
half * VKQ_f16 = VKQ;
#endif
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
}
}
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
KQ_f16 + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, wmma::mem_col_major);
}
+46 -43
View File
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
if (ne2 <= 4) {
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
cudaStream_t stream = ctx.stream();
GGML_ASSERT(nb12 % nb11 == 0);
@@ -2865,14 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2887,30 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
@@ -3000,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->instance == nullptr) {
res = true;
}
// Check if the graph size has changed
if (graph->props.size() != (size_t)cgraph->n_nodes) {
res = true;
@@ -3640,11 +3615,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_add_node;
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;
continue;
@@ -3950,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
#ifdef USE_CUDA_GRAPH
graph_key = ggml_cuda_graph_get_key(cgraph);
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->is_enabled()) {
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
if (graph_compatible) {
const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
graph->record_update(use_cuda_graph, cuda_graph_update_required);
if (!graph->warmup_complete) {
// Warmup: need at least 2 calls with no property change on the 2nd call
if (!properties_changed) {
graph->warmup_complete = true;
GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
use_cuda_graph = true;
cuda_graph_update_required = true;
}
// else: properties changed or first call - execute directly (use_cuda_graph stays false)
} else {
// Post-warmup: normal CUDA graph operation
if (properties_changed) {
// Properties changed - reset warmup, execute directly until stable again
graph->warmup_complete = false;
GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
} else {
use_cuda_graph = true;
cuda_graph_update_required = graph->instance == nullptr;
}
}
}
}
#endif // USE_CUDA_GRAPH
@@ -4542,6 +4540,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
// TODO: should become:
//return ggml_is_contiguous_rows(op->src[0]);
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -4820,8 +4820,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
// TODO: extend support like so:
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:
+21 -16
View File
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR2_XXS; ++l) {
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR2_XS; ++l) {
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR3_XXS; ++l) {
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+1
View File
@@ -1,6 +1,7 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
+31 -17
View File
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
#endif
}
static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
// v is a 7 bit int, with the 8th sign being encodable as popcnt
// with xor we can "correct" the bit instead of having to mask
const uint32_t p = __popc(v) & 1;
const uint32_t s = v ^ p << 7;
// broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
return s * 0x01010101;
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
int sumi = 0;
#pragma unroll
for (int k0 = 0; k0 < 8; k0 += 2) {
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
}
const int ls = aux32 >> 28;
sumi = (ls*sumi + sumi/2)/4;
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
return d * sumi;
}
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
int sumi1 = 0;
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+254 -272
View File
@@ -17,121 +17,6 @@
#include "htp-msg.h"
#include "htp-ops.h"
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// MAD: y (F32) += x (F16) * s (float)
// MAD: y (F32) += x (F16) * s (F32)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
}
}
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
const void * restrict x0,
const void * restrict x1,
float s0,
float s1,
int n) {
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(s0);
HVX_Vector S1 = hvx_vec_splat_f16(s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
}
if (nloe) {
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs = xs_p_lo;
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i;
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
}
}
}
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
struct htp_fa_context {
const struct htp_ops_context * octx;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
float scale;
float max_bias;
float logit_softcap;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t n_blocks;
size_t size_q_row_padded;
size_t size_k_row_padded;
size_t size_v_row_padded;
size_t size_k_block;
size_t size_v_block;
size_t size_m_block;
bool is_q_fp32;
};
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
}
}
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_fa_context * factx = (struct htp_fa_context *) data;
const struct htp_ops_context * octx = factx->octx;
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
// Clear accumulator
hvx_splat_f32_a(spad_a, 0, DV);
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
if (factx->is_q_fp32) {
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
}
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// Inner loop processing the block from VTCM
uint32_t ic = 0;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
// Process in blocks of 32 (VLEN_FP32)
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4;
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; j += 2) {
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else {
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
if (factx->logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
}
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);
}
scores_x4.v[iv] = scores;
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
}
{
// 4. Online Softmax Update
v_max = hvx_vec_reduce_max_f32(v_max);
float m_block = hvx_vec_get_f32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
M_vec = M_new_vec;
const float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = scores_x4.v[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
*(HVX_Vector *) p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
}
}
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}
// Sync scalars for leftover/next block if needed
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
if (factx->logit_softcap != 0.0f) {
s_val = factx->logit_softcap * tanhf(s_val);
}
if (mask) {
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
vs = expf(s_val - M);
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
M_vec = hvx_vec_splat_f32(M);
S_vec = hvx_vec_splat_f32(S);
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
if (ib + 2 < factx->n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
// sinks
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
S = S * ms + vs;
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
struct htp_fa_context factx;
factx.octx = octx;
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
factx.scale = scale;
factx.max_bias = max_bias;
factx.logit_softcap = logit_softcap;
uint32_t n_head = q->ne[2];
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}
return HTP_STATUS_OK;
-13
View File
@@ -64,25 +64,12 @@ struct htp_ops_context {
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
uint32_t flags;
};
+1 -1
View File
@@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
// otherwise we'll release it once we're done with the current Op.
if (ctx->vtcm_inuse) {
ctx->vtcm_needs_release = false;
ctx->vtcm_needs_release = true;
return 0;
}
File diff suppressed because it is too large Load Diff
+4
View File
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
}
}
static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
static inline float ggml_compute_softplus_f32(float input) {
return (input > 20.0f) ? input : logf(1 + expf(input));
}
+13 -2
View File
@@ -264,15 +264,26 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_SUM_ROWS:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_CLAMP:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:
case GGML_OP_UNARY:
case GGML_OP_GET_ROWS:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_REPEAT:
return true;
default:
return ggml_op_is_empty(op);
@@ -312,7 +323,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
h_add(mrs1, node0);
// that many nodes forward to search for a concurrent node
constexpr int N_FORWARD = 8;
constexpr int N_FORWARD = 64;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {
+31 -13
View File
@@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
const char * op_str = "undefined";
int op_num = -1;
switch (op->op) {
case GGML_OP_SUM_ROWS:
op_str = "sum_rows"; break;
case GGML_OP_MEAN:
op_str = "mean"; break;
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d", base, op_num);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32*sizeof(float);
if (is_c4) {
res.smem *= 4;
}
res.c4 = is_c4;
return res;
}
@@ -1480,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1494,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
res.c4 = is_c4;
res.smem = 32*sizeof(float);
return res;
+5 -5
View File
@@ -1019,7 +1019,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
@@ -1039,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
}
@@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_REPEAT:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
@@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
@@ -1160,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
+17 -1
View File
@@ -82,6 +82,7 @@
#define FC_COUNT_EQUAL 1100
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -118,6 +119,8 @@
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
// kernel argument structs
//
@@ -539,8 +542,21 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;
+196 -28
View File
@@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
} break;
case GGML_OP_SET:
{
n_fuse = ggml_metal_op_set(ctx, idx);
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
@@ -616,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -667,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
}
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb00,
/*.nb01 =*/ pnb1,
/*.nb02 =*/ pnb2,
@@ -683,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
@@ -703,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = 1;
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
@@ -904,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -925,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -1599,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
const size_t offs = ((const int32_t *) op->op_params)[3];
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
}
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
int64_t nk0 = ne10;
if (ggml_is_quantized(op->src[1]->type)) {
nk0 = ne10/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne10/ggml_blck_size(op->type);
}
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
}
}
}
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb10,
/*.nb01 =*/ nb11,
/*.nb02 =*/ nb12,
/*.nb03 =*/ nb13,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ ggml_element_size(op),
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
bid_dst.offs += offs;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
return 1;
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -2979,39 +3127,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
+1
View File
@@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
+121 -72
View File
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -910,7 +918,7 @@ constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
@@ -918,10 +926,27 @@ T erf_approx(T x) {
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T>
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
@@ -963,111 +988,111 @@ kernel void kernel_unary_impl(
}
}
device const T0 & x = src0_ptr[i0];
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = args.scale * x + args.bias;
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = args.val;
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = clamp(x, args.min, args.max);
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = x * x;
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = sqrt(x);
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = sin(x);
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = cos(x);
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = log(x);
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope);
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = precise::tanh(x);
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = fmax(0.0f, x);
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = 1.0f / (1.0f + exp(-x));
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x));
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x)));
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = x / (1.0f + exp(-x));
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f);
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = -x;
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = fabs(x);
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f);
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0.0f);
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = exp(x);
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f);
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = exp(x) - 1.0f;
dst_ptr[i0] = (T) (exp(x) - 1);
}
}
@@ -1075,11 +1100,12 @@ kernel void kernel_unary_impl(
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4>;
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
@@ -1483,33 +1509,35 @@ kernel void kernel_op_sum_f32(
}
}
template <bool norm>
kernel void kernel_sum_rows(
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
#define FC_OP FC_sum_rows_op
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
shmem_t[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
@@ -1520,23 +1548,33 @@ kernel void kernel_sum_rows(
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
@@ -2417,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t ne02 = args.ne02;
const int32_t ne03 = args.ne03;
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
@@ -2706,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@@ -2743,12 +2784,16 @@ kernel void kernel_l2_norm_f32(
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -5921,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
const short T = PK + NSG*SH; // shared memory size per query in (half)
//const short T = PK + NSG*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
@@ -8509,7 +8554,9 @@ kernel void kernel_mul_mm(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -8632,8 +8679,8 @@ kernel void kernel_mul_mm(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
@@ -8882,7 +8929,9 @@ kernel void kernel_mul_mm_id(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -9017,8 +9066,8 @@ kernel void kernel_mul_mm_id(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
+6
View File
@@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q4_1_f32
mul_mv_q4_1_f32_flat
mul_mv_q4_k_f32
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
@@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS
gemv_moe_mxfp4_f32
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_general_q8_0_f32
mul
File diff suppressed because it is too large Load Diff
+51
View File
@@ -46,6 +46,15 @@ struct block_q4_0
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// block_q4_1
//------------------------------------------------------------------------------
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
@@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_1
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_1(
global struct block_q4_1 * src0,
global uchar * dst_q,
global half * dst_d,
global half * dst_m
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
global half * m = (global half *) dst_m + get_global_id(0);
*d = b->d;
*m = b->m;
for (int i = 0; i < QK4_1/2; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q4_1(
global uchar * src_q,
global half * src_d,
global half * src_m,
global struct block_q4_1 * dst
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
global half * m = (global half *) src_m + get_global_id(0);
b->d = *d;
b->m = *m;
for (int i = 0; i < QK4_1/2; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------
+87 -56
View File
@@ -3,80 +3,111 @@
//------------------------------------------------------------------------------
// expm1
//------------------------------------------------------------------------------
kernel void kernel_expm1_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
kernel void kernel_expm1_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_mean_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum / ne00;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
kernel void kernel_mean_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
@@ -0,0 +1,163 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,165 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global half * src0_m,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
float m = (float)src0_m[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,158 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 2
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 128; // 2 values per idx
int iqs = idx % 128; // 0..127
int n = iqs / 64; // 0,1
int b = (iqs % 64) / 32; // 0,1
int is_b = (iqs % 16) / 8; // 0,1
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
int is = 8 * n + qhshift + is_b; // 0..15
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,219 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y(
global const struct block_q4_1 * qb_curr,
float sumy,
float16 yl,
int il
) {
float d = qb_curr->d;
float m = qb_curr->m;
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32(
global void * src0,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
@@ -0,0 +1,229 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y_flat(
global const uchar * x,
global const half * dh,
global const half * mh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
float m = *mh;
global const ushort * qs = ((global const ushort *) x + il/2);
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// The number of scales/mins is the same as the number of blocks.
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_dm;
global half * m = (global half *) src0_m + offset0_dm;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
@@ -0,0 +1,180 @@
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define K_SCALE_SIZE 12
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uchar qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32(
global char * src0,
int offset0,
global char * src1,
int offset1,
global char * dst,
int offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
ushort kmask1 = 0x3f3f;
ushort kmask2 = 0x0f0f;
ushort kmask3 = 0xc0c0;
int ix = get_sub_group_local_id()/8; // super block index
int it = get_sub_group_local_id()%8; // block index (inside super block)
int iq = it/4; // 0 or 1 - first or second half of the super block
int ir = it%4; // 0...3 - block index in the half super block
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
global float * y = (global float *) (src1 + offset_src1);
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f};
float all_sum;
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
ushort sc16[4];
uchar * sc8 = (uchar *)sc16;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+0];
sumy.s0 += yl[i+0];
yl[i+8] = y4[i+32];
sumy.s1 += yl[i+8];
yh[i+0] = y4[i+128];
sumy.s2 += yh[i+0];
yh[i+8] = y4[i+160];
sumy.s3 += yh[i+8];
}
global ushort * sc = (global ushort *)x[ib].scales + iq;
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
global half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
global ushort * q2 = q1 + 32;
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
q1 += nb01/2;
sc += nb01/2;
dh += nb01/2;
}
y4 += BLOCK_STRIDE * QK_K;
}
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = sub_group_reduce_add(sumf[row]);
if (first_row + row < ne01) {
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}
}
+88 -60
View File
@@ -3,86 +3,114 @@
//------------------------------------------------------------------------------
// softplus
//------------------------------------------------------------------------------
inline float softplus_f32(float x){
float ax = fabs(x);
float m = fmax(x, 0.0f);
return log1p(exp(-ax)) + m;
kernel void kernel_softplus_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
const float x = convert_float(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
const float4 x = convert_float4(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = softplus_f32(*src_val_ptr);
}
*y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
}
}
kernel void kernel_softplus_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
}
const float x = convert_float(*hx);
*hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_sum_rows_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
kernel void kernel_sum_rows_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
+92 -41
View File
@@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
#define VK_VENDOR_ID_QUALCOMM 0x5143
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
@@ -687,6 +688,7 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
vk_pipeline pipeline_add[2][2][2];
@@ -942,6 +944,7 @@ struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
@@ -961,6 +964,7 @@ struct vk_mat_vec_push_constants {
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
@@ -4080,7 +4084,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4181,7 +4185,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -5641,6 +5646,10 @@ static void ggml_vk_instance_init() {
driver_priorities[vk::DriverId::eMesaNvk] = 2;
#endif
break;
case VK_VENDOR_ID_QUALCOMM:
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
break;
}
driver_priorities[vk::DriverId::eMesaDozen] = 100;
@@ -6766,8 +6775,16 @@ static void ggml_vk_matmul(
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
base_work_group_z += groups_z;
}
return;
}
@@ -6781,9 +6798,17 @@ static void ggml_vk_matmul(
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
base_work_group_z += groups_z;
}
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7179,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (qx_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
@@ -7477,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7572,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags, base_work_group_y,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, groups_y, groups_z });
base_work_group_y += groups_y;
}
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
@@ -7825,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
src1->nb[2] <= src1->nb[1] &&
src1->nb[1] <= src1->nb[3] &&
src0->ne[3] == 1 &&
src1->ne[3] == 1) {
src1->ne[3] == 1 &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
@@ -8422,6 +8457,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
@@ -8438,7 +8475,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t slope = Br * acctype;
const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
@@ -8815,6 +8852,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_acc_f32;
}
return nullptr;
case GGML_OP_SET:
if (src0->type == src1->type && src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
return ctx->device->pipeline_set_f32;
}
return nullptr;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9801,16 +9844,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
0,
0.0f, 0.0f, offset,
});
@@ -10624,8 +10667,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
const float * op_params = (const float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11543,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12052,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
// y[i] = i % k;
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12500,6 +12543,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ACC:
case GGML_OP_SET:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
break;
@@ -14896,8 +14940,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -14960,7 +15006,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_SET:
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
@@ -15611,6 +15660,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_SET) {
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {
+16 -8
View File
@@ -3,6 +3,9 @@
#include "types.glsl"
#include "generic_binary_head.glsl"
// false for SET, true for ACC
layout(constant_id = 1) const bool ACC = true;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
@@ -13,17 +16,22 @@ void main() {
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
const uint i3 = src1_i / p.nb03;
const uint rem2 = src1_i - i3 * p.nb03;
const uint i2 = rem2 / p.nb02;
const uint rem1 = rem2 - i2 * p.nb02;
const uint i1 = rem1 / p.nb01;
const uint i0 = rem1 % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
if (ACC) {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
}
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
}
}
@@ -130,6 +130,7 @@ void main() {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
@@ -137,12 +138,25 @@ void main() {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
@@ -260,6 +274,9 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
@@ -213,6 +215,19 @@ void main() {
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
@@ -176,7 +176,14 @@ void main() {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
@@ -184,7 +191,14 @@ void main() {
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
}
@@ -1,6 +1,6 @@
#version 450
#include "generic_head.glsl"
#include "generic_unary_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,22 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
sum[tid] += xi * xi;
}
@@ -35,7 +38,7 @@ void main() {
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
}
}
@@ -32,6 +32,7 @@ layout (push_constant) uniform parameter
uint expert_i1;
uint nbi1;
#else
uint base_work_group_y;
uint ne02;
uint ne12;
uint broadcast2;
@@ -45,9 +46,9 @@ uint expert_id;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_i0 = gl_GlobalInvocationID.y;
const uint expert_i0 = gl_WorkGroupID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
#endif
#ifndef MUL_MAT_ID
@@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -139,7 +141,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -149,7 +151,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -366,7 +368,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
#ifdef COOPMAT
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -197,7 +199,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -215,7 +217,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -255,7 +257,7 @@ void main() {
#else
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
uint stride_a = p.stride_a / QUANT_K;
@@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -108,7 +110,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -118,7 +120,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -276,7 +278,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,5 +1,4 @@
#decl(BYTE_HELPERS)
#ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}
@@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#endif
#enddecl(BYTE_HELPERS)
#decl(Q4_0_T)
#ifdef Q4_0_T
struct q4_0 {
d: f16,
qs: array<f16, 8>
};
#enddecl(Q4_0_T)
#endif
#decl(Q4_1_T)
#ifdef Q4_1_T
struct q4_1 {
d: f16,
m: f16,
qs: array<u32, 4>
};
#enddecl(Q4_1_T)
#endif
#decl(Q5_0_T)
#ifdef Q5_0_T
struct q5_0 {
d: f16,
qh: array<f16, 2>,
qs: array<f16, 8>
};
#enddecl(Q5_0_T)
#endif
#decl(Q5_1_T)
#ifdef Q5_1_T
struct q5_1 {
d: f16,
m: f16,
qh: u32,
qs: array<u32, 4>
};
#enddecl(Q5_1_T)
#endif
#decl(Q8_0_T)
#ifdef Q8_0_T
struct q8_0 {
d: f16,
qs: array<f16, 16>
};
#enddecl(Q8_0_T)
#endif
#decl(Q8_1_T)
#ifdef Q8_1_T
struct q8_1 {
d: f16,
m: f16,
qs: array<u32, 8>
};
#enddecl(Q8_1_T)
#endif
#decl(Q2_K_T)
struct q2_k {
#ifdef Q2_K_T
struct q2_K {
scales: array<u32, 4>,
qs: array<u32, 16>,
d: f16,
dmin: f16
};
#enddecl(Q2_K_T)
#endif
#decl(Q3_K_T)
struct q3_k {
#ifdef Q3_K_T
struct q3_K {
hmask: array<f16, 16>,
qs: array<f16, 32>,
scales: array<f16, 6>,
d: f16
};
#enddecl(Q3_K_T)
#decl(Q45_K_SCALE_MIN)
#endif
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
if (is < 4) {
let sc_byte = get_byte(scales[is / 4], is % 4);
@@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
return vec2(f32(sc), f32(m));
}
}
#enddecl(Q45_K_SCALE_MIN)
#decl(Q4_K_T)
struct q4_k {
#endif
#ifdef Q4_K_T
struct q4_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qs: array<u32, 32>
};
#enddecl(Q4_K_T)
#endif
#decl(Q5_K_T)
struct q5_k {
#ifdef Q5_K_T
struct q5_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qh: array<u32, 8>,
qs: array<u32, 32>
};
#enddecl(Q5_K_T)
#endif
#decl(Q6_K_T)
struct q6_k {
#ifdef Q6_K_T
struct q6_K {
ql: array<f16, 64>,
qh: array<f16, 32>,
scales: array<f16, 8>,
d: f16
};
#enddecl(Q6_K_T)
#endif
#decl(IQ2_XXS_T)
#ifdef IQ2_XXS_T
struct iq2_xxs {
d: f16,
qs: array<f16, 32>
};
#enddecl(IQ2_XXS_T)
#endif
#decl(IQ2_XS_T)
#ifdef IQ2_XS_T
struct iq2_xs {
d: f16,
qs: array<f16, 32>,
scales: array<f16, 4>
};
#enddecl(IQ2_XS_T)
#endif
#decl(IQ2_S_T)
#ifdef IQ2_S_T
struct iq2_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
scales: array<f16, 4>
};
#enddecl(IQ2_S_T)
#endif
#decl(IQ3_XSS_T)
#ifdef IQ3_XXS_T
struct iq3_xxs {
d: f16,
qs: array<f16, 48>
};
#enddecl(IQ3_XSS_T)
#endif
#decl(IQ3_S_T)
#ifdef IQ3_S_T
struct iq3_s {
d: f16,
qs: array<f16, 32>,
@@ -161,41 +156,41 @@ struct iq3_s {
signs: array<f16, 16>,
scales: array<f16, 2>
};
#enddecl(IQ3_S_T)
#endif
#decl(IQ1_S_T)
#ifdef IQ1_S_T
struct iq1_s {
d: f16,
qs: array<f16, 16>,
qh: array<f16, 8>
};
#enddecl(IQ1_S_T)
#endif
#decl(IQ1_M_T)
#ifdef IQ1_M_T
struct iq1_m {
qs: array<u32, 8>,
qh: array<u32, 4>,
scales: array<u32, 2>
};
#enddecl(IQ1_M_T)
#endif
#decl(IQ4_NL_T)
#ifdef IQ4_NL_T
struct iq4_nl {
d: f16,
qs: array<f16, 8>,
};
#enddecl(IQ4_NL_T)
#endif
#decl(IQ4_XS_T)
#ifdef IQ4_XS_T
struct iq4_xs {
d: f16,
scales_h: f16,
scales_l: u32,
qs: array<u32, 32>
};
#enddecl(IQ4_XS_T)
#endif
#decl(IQ23_TABLES)
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
0x08040201u, // 1, 2, 4, 8
0x80402010u // 16, 32, 64, 128
@@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
);
#enddecl(IQ23_TABLES)
#endif
#decl(IQ2_XXS_GRID)
#ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
);
#enddecl(IQ2_XXS_GRID)
#endif
#decl(IQ2_XS_GRID)
#ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_XS_GRID)
#endif
#decl(IQ2_S_GRID)
#ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_S_GRID)
#decl(IQ3_XSS_GRID)
#endif
#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>(
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
);
#enddecl(IQ3_XSS_GRID)
#decl(IQ3_S_GRID)
#endif
#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>(
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
);
#enddecl(IQ3_S_GRID)
#endif
#decl(IQ1_GRID)
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
const IQ1_DELTA: f32 = 0.125;
@@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
);
#enddecl(IQ1_GRID)
#endif
#decl(IQ4_GRID)
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
const kvalues_iq4nl = array<i32, 16>(
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
);
#enddecl(IQ4_GRID)
#endif
@@ -56,12 +56,46 @@ def expand_includes(shader, input_dir):
return include_pattern.sub(replacer, shader)
def write_shader(shader_name, shader_code, output_dir, outfile):
def chunk_shader(shader_code, max_chunk_len=60000):
"""Split shader_code into safe raw-string sized chunks."""
return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]
def raw_delim(shader_code):
"""Pick a raw-string delimiter that does not appear in the shader."""
delim = "wgsl"
while f"){delim}\"" in shader_code:
delim += "_x"
return delim
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
shader_code = expand_includes(shader_code, input_dir)
if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
f_out.write(shader_code)
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
delim = raw_delim(shader_code)
chunks = chunk_shader(shader_code)
if len(chunks) == 1:
outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
else:
for idx, chunk in enumerate(chunks):
outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n')
outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n')
outfile.write(' static const std::string s = []{\n')
outfile.write(' std::string tmp;\n')
outfile.write(f' tmp.reserve({len(shader_code)});\n')
for idx in range(len(chunks)):
outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n')
outfile.write(' return tmp;\n')
outfile.write(' }();\n')
outfile.write(' return s;\n')
outfile.write('}\n')
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
def generate_variants(fname, input_dir, output_dir, outfile):
@@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
@@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():
@@ -137,7 +171,8 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n\n")
out.write("// Auto-generated shader embedding\n")
out.write("#include <string>\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(fname, args.input_dir, args.output_dir, out)
@@ -1,222 +1,31 @@
#define(VARIANTS)
enable f16;
#include "common_decls.tmpl"
[
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"BLOCK_SIZE": 4
},
"DECLS": ["F32_VEC"]
},
{
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F32"]
},
{
"REPLS": {
"TYPE" : "f16",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F16"]
},
{
"REPLS": {
"TYPE" : "i32",
"DST_TYPE": "i32",
"BLOCK_SIZE": 1
},
"DECLS": ["I32"]
},
{
"REPLS": {
"TYPE" : "q4_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"TYPE" : "q4_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"TYPE" : "q5_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"TYPE" : "q5_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"TYPE" : "q8_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"TYPE" : "q2_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"TYPE" : "q3_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"TYPE" : "q4_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"TYPE" : "q5_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"TYPE" : "q6_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"TYPE" : "iq2_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"TYPE" : "iq2_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"TYPE": "iq2_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"TYPE": "iq3_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"TYPE": "iq3_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"TYPE": "iq1_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"TYPE": "iq1_m",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"TYPE": "iq4_nl",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"TYPE": "iq4_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(F32_VEC)
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
}
#enddecl(F32_VEC)
#endif
#decl(F32)
#ifdef F32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(F32)
#endif
#decl(F16)
#ifdef F16
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = f32(src[src_base + offset]);
}
#enddecl(F16)
#endif
#decl(I32)
#ifdef I32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(I32)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_0 = src[src_base + offset];
let d = f32(block_q4_0.d);
@@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_1 = src[src_base + offset];
let d = f32(block_q4_1.d);
@@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_0 = src[src_base + offset];
let d = f32(block_q5_0.d);
@@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(Q5_0)
#decl(Q5_1)
#ifdef Q5_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_1 = src[src_base + offset];
let d = f32(block_q5_1.d);
@@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q8_0 = src[src_base + offset];
let d = f32(block_q8_0.d);
@@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q8_0)
#endif
#decl(Q2_K)
#ifdef Q2_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q2_K)
#endif
#decl(Q3_K)
#ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q3_K)
#endif
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_K)
#endif
#decl(Q5_K)
#ifdef Q5_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_K)
#endif
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
sc_b_idx += 8;
}
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XXS)
#endif
#decl(IQ2_XS)
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XS)
#endif
#decl(IQ2_S)
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_XSS)
#endif
#decl(IQ3_S)
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i++;
}
}
#enddecl(IQ4_NL)
#endif
#decl(IQ4_XS)
#ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i += 16;
}
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> idx: array<i32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
offset_src: u32, // in elements
@@ -842,8 +638,7 @@ struct Params {
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
@@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
}
#end(SHADER)
@@ -1,195 +1,24 @@
#define(VARIANTS)
enable f16;
[
{
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q8_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q2_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q3_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q6_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_m",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_nl",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#include "common_decls.tmpl"
#end(VARIANTS)
#ifdef FLOAT
const BLOCK_SIZE = 1u;
#define(DECLS)
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;
#decl(FLOAT)
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif
#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
}
#enddecl(FLOAT)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_0 = src0[src0_idx_base + offset];
let d = f32(block_q4_0.d);
@@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_1 = src0[src0_idx_base + offset];
let d = f32(block_q4_1.d);
@@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_0 = src0[src0_idx_base + offset];
let d = f32(block_q5_0.d);
@@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_0)
#endif
#decl(Q5_1)
#ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_1 = src0[src0_idx_base + offset];
let d = f32(block_q5_1.d);
@@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_0 = src0[src0_idx_base + offset];
let d = f32(block_q8_0.d);
@@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_0)
#endif
#decl(Q8_1)
#ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_1 = src0[src0_idx_base + offset];
let d = f32(block_q8_1.d);
@@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_1)
#endif
#decl(Q2_K)
#ifdef Q2_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q2_K)
#decl(Q3_K)
#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q3_K)
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q4_K)
#decl(Q5_K)
#ifdef Q5_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q5_K)
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XXS)
#decl(IQ2_XS)
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XS)
#decl(IQ2_S)
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ3_XSS)
#decl(IQ3_S)
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ4_NL)
#decl(IQ4_XS)
#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
struct MulMatParams {
offset_src0: u32, // in elements/blocks
@@ -864,8 +672,8 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
var sum = 0.0;
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i);
}
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}
#end(SHADER)
@@ -1,58 +1,65 @@
#decl(SHMEM_VEC)
#ifdef VEC
#define VEC_SIZE 4
#define SHMEM_TYPE vec4<f16>
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#enddecl(SHMEM_VEC)
#endif
#ifdef SCALAR
#define VEC_SIZE 1
#define SHMEM_TYPE f16
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
#decl(SHMEM_SCALAR)
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)
#decl(INIT_SRC0_SHMEM_FLOAT)
#endif
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let src0_val = select( // taking a slight performance hit to avoid oob
{{SRC0_TYPE}}(0.0),
src0[src0_idx/{{VEC_SIZE}}],
SRC0_TYPE(0.0),
src0[src0_idx/VEC_SIZE],
global_m < params.m && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
#endif
#enddecl(INIT_SRC0_SHMEM_FLOAT)
#decl(INIT_SRC1_SHMEM)
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
let tile_n = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_n = offset_n + tile_n;
let global_k = k_outer + tile_k;
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
let src1_val = select(
{{SRC1_TYPE}}(0.0),
src1[src1_idx/{{VEC_SIZE}}],
SRC1_TYPE(0.0),
src1[src1_idx/VEC_SIZE],
global_n < params.n && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
#endif
#enddecl(INIT_SRC1_SHMEM)
#decl(INIT_SRC0_SHMEM_Q4_0)
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
@@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
}
#enddecl(INIT_SRC0_SHMEM_Q4_0)
#endif
@@ -1,115 +1,19 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f32_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f32>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f32_f32",
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16_vec",
"REPLS": {
"SRC0_TYPE" : "vec4<f16>",
"SRC1_TYPE" : "vec4<f16>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "f16_f16",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32_vec",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "vec4<f32>",
"DST_TYPE" : "vec4<f32>",
"SHMEM_TYPE" : "vec4<f16>",
"VEC_SIZE" : 4,
},
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
},
{
"SHADER_SUFFIX": "q4_0_f32",
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"DST_TYPE" : "f32",
"SHMEM_TYPE" : "f16",
"VEC_SIZE" : 1,
},
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
}
]
enable f16;
#end(VARIANTS)
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#define(DECLS)
#decl(VEC)
#ifdef VEC
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
}
#enddecl(VEC)
#endif
#decl(SCALAR)
#ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return f32(acc[tm][tn]);
}
#enddecl(SCALAR)
#end(DECLS)
#define(SHADER)
enable f16;
#endif
struct MulMatParams {
offset_src0: u32,
@@ -130,14 +34,12 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@group(0) @binding(3) var<uniform> params: MulMatParams;
DECLS
fn get_local_n(thread_id: u32) -> u32 {
return thread_id / WORKGROUP_SIZE_M;
}
@@ -145,18 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
return thread_id % WORKGROUP_SIZE_M;
}
// TILE_M must be multiple of 4 for vec4 loads
const TILE_M = {{WEBGPU_TILE_M}}u;
const TILE_N = {{WEBGPU_TILE_N}}u;
override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@@ -233,15 +126,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
for (var tn = 0u; tn < TILE_N; tn++) {
let global_col = output_col_base + tn;
if (global_col < params.n) {
for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
let global_row = output_row_base + tm;
if (global_row < params.m) {
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
}
}
}
}
}
#end(SHADER)

Some files were not shown because too many files have changed in this diff Show More