Compare commits

...

105 Commits

Author SHA1 Message Date
Georgi Gerganov 27326bfce1 models : dedup qwen35 graphs (#19660)
* models : dedup qwen35 graphs

* cont : add missing sigmoid
2026-02-19 08:17:49 +02:00
ymcki ad9f692f8f models : dedup Kimi Linear delta net implementation (#19668)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments [no ci]

* add kimi linear to delta-net-base

* removed unnecessary ggml_cont from g_exp_t

* removed ggml_cont from g_diff_exp_t. moved ggml_cont for o to kimi-linear.cpp

* removed unnecessary diag mask

* cont : simplify

* cont : avoid graph splits

* scale q after mul instead of beginning

* scale q after mul instead of beginning

* identical ppl

* cont : fix scale and decay mask

* minor : remove TODO

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-19 08:15:17 +02:00
Piotr Wilkin (ilintar) 8a70973557 Add Jinja support for "indent" string filter (#19529)
* Add partial Jinja support for "indent" string filter

* Fully implement indent

* Add tests for all width variants.

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix getline ignoring trailing newlines

* Update common/jinja/value.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix first indent condition

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-19 00:25:52 +01:00
Reese Levine e7f2f95c9a ggml webgpu: Fix bug in dispatching large matrix-vector multiplication (#19535)
* Fix bug in dispatching large matrix-vector multiplication
2026-02-18 16:06:29 -07:00
matteo b55dcdef5d server: save generated text for the /slots endpoint (for LLAMA_SERVER_SLOTS_DEBUG=1) (#19622)
* save generated text for the /slots endpoint

* update debug_generated_text only when LLAMA_SERVER_SLOTS_DEBUG > 0

* Apply suggestions from code review

---------

Co-authored-by: Matteo <matteo@matteo>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-18 18:53:37 +01:00
Xuan-Son Nguyen eeef3cfced model: support GLM-OCR (#19677)
* model: support GLM-OCR

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-18 17:51:40 +01:00
Maciej Lisowski e99f1083a0 docs: Fix broken links for preparing models in Backends (#19684) 2026-02-18 23:50:23 +08:00
Reese Levine 238856ec8f ggml webgpu: shader library organization (#19530)
* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* flashattention and matrix multiplication moved to new format

* clean up preprocessing

* Formatting

* remove duplicate constants

* Split large shaders into multiple static strings

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
2026-02-18 07:51:02 -07:00
Aleksander Grygier ea003229d3 Pre-MCP UI and architecture cleanup (#19689) 2026-02-18 12:02:02 +01:00
Jeff Bolz d0061be838 vulkan: split mul_mat into multiple dispatches to avoid overflow (#19509)
* vulkan: split mul_mat into multiple dispatches to avoid overflow

The batch dimensions can be greater than the max workgroup count limit,
in which case we need to split into multiple dispatches and pass the base
index through a push constant.

Fall back for the less common p021 and nc variants.

* address feedback
2026-02-18 10:47:10 +01:00
Adrien Gallouët a569bda445 common : make small string helpers as inline functions (#19693)
Also use string_view when it make sense and fix some corner cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-18 08:03:01 +01:00
shaofeiqi e2f19b320f opencl: refactor expm1 and softplus (#19404)
* opencl: refactor expm1

* opencl: refactor softplus

* opencl: use h for half literals

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 14:47:18 -08:00
shaofeiqi 983559d24b opencl: optimize mean and sum_row kernels (#19614)
* opencl: optimize mean and sum_row kernels

* opencl: add comment for max subgroups

* opencl: format

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-02-17 13:56:09 -08:00
Daniel Bevenius 2b089c7758 model-conversion : add option to print tensor values (#19692)
This commit updates the tensor-info.py script to support the option to
print the first N values of a tensor when displaying its information.

The motivation for this is that it can be useful to inspect some actual
values in addition to the shapes of the tensors.
2026-02-17 20:43:22 +01:00
Aleksander Grygier afa6bfe4f7 Pre-MCP UI and architecture cleanup (#19685)
* webui: extract non-MCP changes from mcp-mvp review split

* webui: extract additional pre-MCP UI and architecture cleanup

* chore: update webui build output
2026-02-17 13:47:45 +01:00
Talha Can Havadar ae2d3f28a8 ggml: ggml-cpu: force-no-lto-for-cpu-feats (#19609)
When LTO enabled in build environments it forces all builds to have LTO
in place. But feature detection logic is fragile, and causing Illegal
instruction errors with lto. This disables LTO for the feature
detection code to prevent cross-module optimization from inlining
architecture-specific instructions into the score function. Without this,
LTO can cause SIGILL when loading backends on older CPUs (e.g., loading
power10 backend on power9 crashes before feature check runs).
2026-02-17 13:22:46 +02:00
Georgi Gerganov ad8207af77 cuda : enable CUDA graphs for MMID 1 <= BS <= 4 (#19645)
* cuda : enable CUDA graphs for MMID BS <= 4

* cont : add stream capture check

Co-authored-by: Oliver Simons <osimons@nvidia.com>

* cont : add MMVQ_MMID_MAX_BATCH_SIZE

---------

Co-authored-by: Oliver Simons <osimons@nvidia.com>
2026-02-17 12:31:49 +02:00
Daniel Bevenius 667b694278 model-conversion : make printing of config values optional (#19681)
* model-conversion : make printing of config values optional

This commit updates run-org-model.py to make the printing of model
configuration values optional.

The motivation for this change is that not all models have these
configuration values defined and those that do not will error when
running this script. With these changes we only print the values if they
exist or a default value.

We could optionally just remove them but it can be useful to see these
values when running the original model.
2026-02-17 10:46:53 +01:00
Sigbjørn Skjæret e48349a49d ci : bump komac version (#19682) 2026-02-17 09:30:31 +01:00
Adrien Gallouët ae46a61e41 build : link ws2_32 as PUBLIC on Windows (#19666)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-17 08:37:07 +01:00
Adrien Gallouët 65cede7c70 build : cleanup library linking logic (#19665)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-17 08:36:45 +01:00
DAN™ 05fa625eac convert : add JoyAI-LLM-Flash (#19651)
* convert_hf_to_gguf: add JoyAI-LLM-Flash tokenizer hash mapping to deepseek-v3

* llama-vocab: create a new pre-tokenizer name for joyai-llm.

* add missing vocab type section

* Update convert_hf_to_gguf_update.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-16 22:49:57 +01:00
AesSedai d612901116 perplexity: add proper batching (#19661) 2026-02-16 18:44:44 +02:00
Ivan Chikish cceb1b4e33 common : inline functions (#18639) 2026-02-16 17:52:24 +02:00
Judd d23a55997d ggml : make ggml_is_view as API (#19539)
* make `ggml_is_view` as API

* introduce `ggml_aux_is_view` as inline version for internal use.

* change `ggml_aux_is_view` to  `ggml_impl_is_view`
2026-02-16 17:43:34 +02:00
Saurabh Dash 5f28c53d11 model: Add support for Tiny Aya Models (#19611)
* changes for tiny aya

* changes to hash

* changes to vocab

* fix some tokenizer regex edge cases

* update comment

* add some comments for regex

* Apply suggestion from @ngxson

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-02-16 16:28:46 +01:00
Adrien Gallouët 4408494144 build : rework llama_option_depr to handle LLAMA_CURL (#19658)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-16 16:06:48 +01:00
Mario Limonciello 2ba9adc093 Adjust workaround for ROCWMMA_FATTN/GFX9 to only newer ROCm veresions (#19591)
Avoids issues with ROCm 6.4.4.

Closes: https://github.com/ggml-org/llama.cpp/issues/19580
Fixes: 6845f7f87 ("Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)")

Signed-off-by: Mario Limonciello (AMD) <superm1@kernel.org>
2026-02-16 14:46:08 +01:00
Georgi Gerganov cc45f2ada6 models : deduplicate delta-net graphs for Qwen family (#19597)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments
2026-02-16 14:35:04 +02:00
Georgi Gerganov d5dfc33027 graph : fix KQ mask, lora, cvec reuse checks (#19644)
* graph : fix KQ mask reuse condition

* cont : dedup KQ mask build and can_reuse

* cont : fix build

* graph : fix adapter check for reuse
2026-02-16 09:21:11 +02:00
abhijain1204fujitsu 267ba5a1d9 ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel (#19132)
* Updated repack.cpp

* Updated repack.cpp

* Updated repack.cpp

* Added if condition to support only vector length 256.

* Changed the format removed comments and duplicate variable

* If SVE 256 not present then was using generic function to compute, hence slowing the performance. 

So added code if SVE 256 is not present then use NEON code.

* Code format change suggestion

---------

Co-authored-by: Vithule, Prashant <Prashant.Vithule@fujitsu.com>
2026-02-16 14:38:43 +08:00
Georgi Gerganov ff4affb4c1 sync : ggml 2026-02-15 22:24:29 +02:00
Georgi Gerganov 55d58599c8 ggml : bump version to 0.9.7 (ggml/1425) 2026-02-15 22:24:29 +02:00
Georgi Gerganov 1a8c700bfd ggml : bump version to 0.9.6 (ggml/1423) 2026-02-15 22:24:29 +02:00
David Friehs 27b93cbd15 cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization (#19624)
* cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization

- load all 8 int8 for a grid position in one load
- calculate signs via popcnt instead of fetching from ksigns table
- broadcast signs to drop individual shift/mask

* cuda: iq2xxs: simplify sum scaling

express `(sum * scale + sum / 2) / 4` as `(sum * (scale * 2 + 1)) / 8`
express `((aux32 >> 28) * 2 + 1)` as `(aux32 >> 27 | 1)`

saves 3 registers for mul_mat_vec_q (152 -> 149) according to nsight
AFAICT no overflow can occur here as iq2xxs values are far too small

* uint -> uint32_t

error: identifier "uint" is undefined
2026-02-15 22:38:42 +05:30
Aaron Teo 6e67fd2144 docs: update s390x build docs (#19643) 2026-02-16 00:33:34 +08:00
Adrien Gallouët 9e118b97c4 build : remove LLAMA_HTTPLIB option (#19623)
This option was introduced as a workaround because cpp-httplib could not
build on visionOS. Since it has been fixed and now compiles on all platforms,
we can remove it and simplify many things.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-15 15:38:50 +01:00
Daniel Bevenius 57088276d4 cmake : check if KleidiAI API has been fetched (#19640)
This commit addresses a build issue with the KleidiAI backend when
building multiple cpu backends. Commmit
3a00c98584 ("cmake : fix KleidiAI install
target failure with EXCLUDE_FROM_ALL") introduced a change where
FetchContent_Populate is called instead of FetchContent_MakeAvailable,
where the latter does handle this case (it is idempotent but
FetchContent_Populate is not).

I missed this during my review and I should not have commited without
verifying the CI failure, sorry about that.
2026-02-15 13:59:38 +01:00
Georgi Gerganov 341bc7d23c context : fix output reorder with backend sampling (#19638) 2026-02-15 14:57:40 +02:00
Georgi Gerganov 08e6d914b8 ggml : avoid UB in gemm ukernel (#19642) 2026-02-15 14:56:35 +02:00
Aaron Teo 184c694f45 ggml-cpu: optimize ggml_vec_dot_bf16 for s390x (#19399) 2026-02-15 18:20:35 +08:00
Aman Gupta 684b36101c ggml-cpu: FA add GEMM microkernel (#19422)
* ggml-cpu: FA add GEMM microkernel

* add guard for sizeless vector types

* fix case where DV % GGML_F32_EPR !=0

* move memset out of the loop

* move another memset out of the loop

* use RM=4 for arm

* simd_gemm: convert everything to int

* convert everything to size_t to avoid warnings

* fixup

* add pragma for ignoring aggressive loop optimizations
2026-02-15 11:09:24 +05:30
SamareshSingh 3a00c98584 cmake : fix KleidiAI install target failure with EXCLUDE_FROM_ALL (#19581)
* cmake: fix KleidiAI install target failure with EXCLUDE_FROM_ALL

Fix for the bug #19501 by adding EXCLUDE_FROM_ALL to FetchContent_Declare. This properly excludes KleidiAI from both build and install targets, preventing install failures when GGML_CPU_KLEIDIAI=ON is used.

The KleidiAI source files are still compiled into libggml-cpu.so, preserving all functionality.

* addressed code review comments
2026-02-15 06:22:53 +01:00
Sigbjørn Skjæret 079feab9e3 convert : ensure all models handle new experts count (#19621)
* ensure all models handle new experts count

* revert removal for PhiMoeModel, does not inherit from base
2026-02-14 22:22:32 +01:00
Anav Prasad 01d8eaa28d mtmd : Add Nemotron Nano 12B v2 VL support (#19547)
* nemotron nano v2 vlm support added

* simplified code; addressed reviews

* pre-downsample position embeddings during GGUF conversion for fixed input size
2026-02-14 14:07:00 +01:00
Georgi Gerganov 1725e316c1 models : optimize qwen3next graph (#19375)
* models : optimizing qwen3next graph

* cont

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* cont : remove redundant q, g chunking

* minor

* minor

* avoid passing masks around

* avoid concats during chunking

* naming + shapes

* update names and use prefix to disable CUDA graphs
2026-02-14 12:57:36 +02:00
Adrien Gallouët b7742cf321 ggml : fix GGML_DEBUG with OpenMP (#19599)
last_graph is only available without OpenMP, but
ggml_graph_compute_thread() is called in both cases.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 11:22:57 +01:00
iMil badba89320 NetBSD build support (#19589) 2026-02-14 09:47:01 +01:00
Aleksander Grygier baa12f3831 webui: Architecture and UI improvements (#19596) 2026-02-14 09:06:41 +01:00
agent-enemy-2 2d8015e8a4 llama : update LoRA API. + fix excessive graph reserves (#19280)
* Refactoring to use new llama_put_adapter_loras

* cont : alternative lora API

---------

Co-authored-by: Jake Chavis <jakechavis6@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-14 10:06:27 +02:00
George eb145c0753 mmap: Fix Windows handle lifetime (#19598)
* ggml: added cleanups in ggml_quantize_free
Add missing cleanup calls for IQ2_S, IQ1_M quantization types and IQ3XS with 512 blocks during quantization cleanup.

* mmap: Fix Windows handle lifetime
Move hMapping from local variable to member variable so it stays alive for the entire lifetime of the mapping.
The file mapping handle must remain valid until UnmapViewOfFile is called.
Fixes cleanup order in destructor.

* Update llama-mmap.cpp

* Update llama-mmap.cpp

Remove trailing whitespace from line 567
2026-02-14 10:05:12 +02:00
Georgi Gerganov 6e473fb384 metal : fix ACC op (#19427) 2026-02-14 09:54:03 +02:00
Adrien Gallouët c7db95f106 scripts : use official split.py for cpp-httplib (#19588)
* scripts : use official split.py for cpp-httplib

Using the official script is safer and ensures the generated code aligns
with the library's standards.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Catch generic errors

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Allow print()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Ensure robust cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 08:41:16 +01:00
Sigbjørn Skjæret 0d00ef65ed convert : store ffn_gate_inp_shexp as F32 (#19606) 2026-02-14 08:17:43 +01:00
Adrien Gallouët 91ea5d67f2 build : fix libtool call in build-xcframework.sh (#19605)
Run libtool via xcrun like strip and dsymutil, to have proper tool resolution.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-14 06:48:37 +01:00
Jeff Bolz dbb023336b vulkan: support L2_NORM with contiguous rows (#19604) 2026-02-14 06:42:04 +01:00
Jeff Bolz 53aef25a88 vulkan: support GGML_OP_SET (#19584) 2026-02-14 06:36:38 +01:00
Sophon 2dec548094 vulkan: Add vendor id for Qualcomm drivers (#19569)
This commit allows Qualcomm native vulkan driver to be used on Windows
instead of Mesa Dozen.
2026-02-14 06:29:17 +01:00
Max Krasnyansky 0ccbfdef3e hexagon: further optimizations and refactoring for flash attention (#19583)
* ggml-hexagon: fa improvements

ggml-hexagon: optimize flash attention calculations with improved variable handling

ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32

ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx2 by simplifying variable handling for unused elements

ggml-hexagon: optimize flash attention by changing slope vector type to F16

* hexfa: fixed test-backend-ops failurs due to leftover element handling

* hexagon: refactor and optimize fa to use local context struct

* ggml-hexagon: optimize flash-attention using hvx_vec_expf

Use HVX for online softmax.

---------

Co-authored-by: chraac <chraac@gmail.com>
2026-02-13 16:27:30 -08:00
Mengsheng Wu 94a602db66 github : add missing backends to issue templates (#19603) 2026-02-14 00:56:53 +01:00
Jeff Bolz 05a6f0e894 vulkan: restore -inf check in FA shaders (#19582) 2026-02-13 13:35:29 -06:00
Adrien Gallouët b48e80f677 common : update download code (#19573)
* common : remove legacy .json to .etag migration code

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* common : simplify common_download_file_single_online

This commit also force a redownload if the file exists
but has no .etag file.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 15:10:46 +01:00
Xuan-Son Nguyen 752584d5f5 model: support GLM MoE DSA arch (NOTE: indexer is not yet supported) (#19460)
* model: support GLM MoE DSA arch

* working version

* pyright

* keep indexer tensors

* add indexer gguf params

* loaded now

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* minor fix and cleanup

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-13 14:56:53 +01:00
Alberto Cabrera Pérez cc2aa81513 Fix wrong memcpy length for block_interleave == 4 (#19575) 2026-02-13 20:32:14 +08:00
ymcki 0e21991472 fix vulkan ggml_acc only works in 3d but not 4d (#19426)
* fix vulkan ggml_acc only works in 3d but not 4d

* removed clamp in test_acc_block

* use the correct stride and its test case

* cuda : fix "supports op" condition

* change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check

* version without boundary check

* revert back to boundary check version

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-02-13 13:31:37 +01:00
Sigbjørn Skjæret b2ecc0cdb4 support --verbose-prompt (#19576) 2026-02-13 12:49:10 +01:00
Aman Gupta 5065da554e CUDA: loop over ne2*ne3 in case it overflows (#19538)
* CUDA: loop over ne2*ne3 in case it overflows

* use fastdiv
2026-02-13 17:01:40 +05:30
Aleksander Grygier 5174d7206f webui: UI and routing fixes (#19586)
* chore: update webui build output

* chore: update webui build output

* fix: Scroll issues in DropdownMenuSearchable

* webui: fix redirect to root ignoring base path

* fix: Word wrapping

* fix: remove obsolete modality UI tests causing CI failures

- Remove VisionModality/AudioModality test stories
- Remove mockServerProps usage and imports
- Simplify Default test (remove dropdown interaction checks)
- Simplify FileAttachments test (remove mocks)

* feat: Improve formatting performance time

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-02-13 12:31:00 +01:00
Oliver Simons 43919b7f4f CUDA: Do not mutate cgraph for fused ADDs (#19566)
* Do not mutate cgraph for fused ADDs

1. We should try to minimize in-place changes to the incoming
   ggml_cgraph where possible (those should happen in graph_optimize)
2. Modifying in-place leads to an additional, unnecessary graph capture
   step as we store the properties before modifying the graph in-place
   in the cuda-backend

* Assert ggml_tensor is trivially copyable

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-02-13 15:07:55 +05:30
Pavan Shinde 423cf0b26f docs : fix broken link and typo (#19560) 2026-02-13 09:38:09 +01:00
ymcki 33a56f90a6 model : Kimi Linear fix conv state update (#19531)
* fix conv state update for llama-server parallel serving

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-02-13 09:10:18 +01:00
Adrien Gallouët 25224c8021 llama : remove deprecated codecvt (#19565)
Using the same conversion function ensures a consistent matching between
the regex pattern and the text.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:53 +01:00
Adrien Gallouët 2f5d8f8edc vendor : update BoringSSL to 0.20260211.0 (#19562)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-13 06:43:26 +01:00
Georgi Gerganov bb96bfd361 memory : fix kv cache size for hybrid models (#19559) 2026-02-13 07:36:24 +02:00
Georgi Gerganov 0644baefde metal : improve concurrency (#19555) 2026-02-13 07:35:57 +02:00
Georgi Gerganov 490eb96b88 metal : support GGML_OP_SET (#19548) 2026-02-13 07:34:52 +02:00
Shupei Fan 3bb78133ab hexagon: fix typo in vtcm_needs_release (#19545) 2026-02-12 15:07:49 -08:00
lhez 79cc0f2daf opencl: add basic support for q4_1 (#19534)
* opencl: add q4_1 mv

* opencl: clean up

* opencl: add flattened q4_1 mv

* opencl: clean up

* opencl: add basic q4_1 mm

* opencl: fix whitespace

* opencl: add general q4_0 mm
2026-02-12 14:52:37 -08:00
Georgi Gerganov 338085c69e args : add -kvu to llama-parallel (#19577) 2026-02-12 21:52:41 +02:00
Aleksander Grygier 4c61875bf8 webui: Add switcher to Chat Message UI to show raw LLM output (#19571) 2026-02-12 19:55:51 +01:00
Adrien Gallouët 4b385bfcf8 vendor : update cpp-httplib (#19537)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-12 16:11:22 +01:00
Christian Schmitz f488429380 llama : update outdated comment in llama.h (#19428)
* Updated documentation

Model is no longer a parameter

* llama : fix trailing whitespace in comment

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
2026-02-12 15:52:57 +01:00
Aleksander Grygier 4d688f9ebb (webui) FEATURE: Enable adding or injecting System Message into chat (#19556)
* feat: Enable adding System Prompt per-chat

* fix: Save draft message in Chat Form when adding System Prompt from new chat view

* fix: Proper system message deletion logic

* chore: Formatting

* chore: update webui build output
2026-02-12 13:56:08 +01:00
Daniel Bevenius ff599039a9 scripts : add support for forks in pr2wt.sh (#19540)
This commit adds support for using the pr2wt.sh (pull request to
workspace) script with forks of upstream llama.cpp.
2026-02-12 13:14:28 +01:00
Aleksander Grygier f486ce9f30 (webui) REFACTOR: UI primitives and polish (#19551)
* webui: UI primitives and polish (non-MCP)

* chore: update webui build output
2026-02-12 12:21:00 +01:00
Aleksander Grygier 38adc7d469 WebUI Architecture Cleanup (#19541)
* webui: architecture foundation (non-MCP core refactors)

* chore: update webui build output
2026-02-12 11:22:27 +01:00
Georgi Gerganov 3b3a948134 metal : update sum_rows kernel to support float4 (#19524) 2026-02-12 11:35:28 +02:00
Mario Limonciello 6845f7f87f Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)
There is an upstream problem [1] with AMD's LLVM 22 fork and
rocWMMA 2.2.0 causing compilation issues on devices without
native fp16 support (CDNA devices).

The specialized types aren't resolved properly:
```
/opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>'
 2549 |             using ARegsT = typename Impl::ARegsT;
```

Add a workaround to explicitly declare the types and cast when
compiling with HIP and ROCWMMA_FATTN [2].  When this is actually
fixed upstream some guards can be used to detect and wrap the
version that has the fix to only apply when necessary.

Link: https://github.com/ROCm/rocm-libraries/issues/4398 [1]
Link: https://github.com/ggml-org/llama.cpp/issues/19269 [2]

Signed-off-by: Mario Limonciello <mario.limonciello@amd.com>
2026-02-12 09:38:35 +01:00
RichardScottOZ fa16e517a3 server : fix typo in README.md for features list (#19510)
extra l for full
2026-02-12 08:56:25 +01:00
TriDefender 313493de53 docs : update path in snapdragon README.md (#19533)
paths changed so original example didn't work
2026-02-12 08:13:51 +01:00
Max Krasnyansky b1ff83bbb0 hexagon: further optimization and tuning of matmul and dot kernels (#19407)
* ggml-hexagon: implement 2x2 matmul kernel

* hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4

* hexagon: fix editor config failures

* hexagon: refactor matmul ops to use context struct and remove wrappers

Also implement vec_dot_f16 2x2

* hexagon: refactor dyn quantizers to use mmctx

* hexagon: remove mm fastdiv from op_ctx

* hexagon: refactor matmul entry point to reduce code duplication

---------

Co-authored-by: Trivikram Reddy <tamarnat@qti.qualcomm.com>
2026-02-11 23:04:27 -08:00
Adrien Gallouët 4ae1b7517a common : replace deprecated codecvt using parse_utf8_codepoint (#19517)
Signed-off-by: Adrien Gallouët <adrien@gallouet.fr>
2026-02-12 07:27:52 +01:00
lhez 4d3daf80f8 opencl: add general Q6_K mm and Q4_K mv (#19347)
* opencl: add general q6_k mm

* opencl: refine condition for q6_K mm

* opencl: add general q4_K mv

* opencl: fix whitespace
2026-02-11 10:33:13 -08:00
Georgi Gerganov 914dde72ba ggml : unary ops support non-cont src0 + metal F16 unary ops (#19511)
* ggml : unary ops support non-cont src0

* metal : support F16 unary ops + fix ELU
2026-02-11 18:58:43 +02:00
Daniel Bevenius 3136a849db common : remove unused token util functions (#19506)
This commit removes two unused functions `common_lcp` and `common_lcs`.
The last usage of these functions was removed in
Commit 33eff40240 ("server : vision support
via libmtmd") and are no longer used anywhere in the codebase.
2026-02-11 17:41:35 +01:00
AesSedai e463bbdf65 model: Add Kimi-K2.5 support (#19170)
* Move dequant_model to after the text_config merge
Add new kimi-k2.5 keys to mtmd convert
Update V_MMPROJ tensor mapping for new mm_projector.proj keys
Update V_M_IMP_NORM for new mm_projector.pre_norm key

* Fix a couple of oversights

* Add image support for Kimi-K2.5

* Revert changes to KimiVLForConditionalGeneration

* Fix an assert crash

* Fix permute swapping w / h on accident

* Kimi-K2.5: Use merged QKV for vision

* Kimi-K2.5: pre-convert vision QK to use build_rope_2d

* Kimi-K2.5: support non-interleaved rope for vision

* Kimi-K2.5: fix min / max pixel

* Kimi-K2.5: remove v/o permutes, unnecessary

* Kimi-K2.5: update permute name to match

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Kimi-K2.5: replace build_rope_2d ggml_cont with ggml_view_3d pointers

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-11 16:47:30 +01:00
Daniel Bevenius 53de59f67d build : fix case in dSYMs path for build-macos [no ci] (#19515)
This commit updates an incorrect dSYMs where the the 's' was uppercase
by mistake.

The motivation for fixing this is that this can cause issues on case
sensitive operating systems.

Refs: https://github.com/ggml-org/whisper.cpp/pull/3630
2026-02-11 14:02:29 +01:00
Georgi Gerganov 9ab072ebbe metal : extend l2_norm support for non-cont src0 (#19502) 2026-02-11 14:53:19 +02:00
Johannes Gäßler ada90bf2ba docs: ban AI for issues and discussions [no CI] (#19512) 2026-02-11 12:49:40 +01:00
Adrien Gallouët 0c1f39a9ae common : improve download error reporting (#19491)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-11 09:27:55 +01:00
Max Krasnyansky 73cd5e1b97 hexagon: Add ARGSORT, DIV, SQR, SQRT, SUM_ROWS, GEGLU (#19406)
* hexagon: add ARGSORT op

Co-authored-by: Yarden Tal <yardent@qti.qualcomm.com>

* hexagon: argsort reject tensors with huge rows for now

* Adding support for DIV,SQR,SQRT,SUM_ROWS ops in hexagon backend

* hexagon : Add GEGLU op

* hexagon: fix editor config check

* hexagon: rewrite and optimize binary ops ADD/SUB/MUL/DIV/ADD_ID to use DMA

---------

Co-authored-by: Yarden Tal <yardent@qti.qualcomm.com>
Co-authored-by: Manohara Hosakoppa Krishnamurthy <mhosakop@qti.qualcomm.com>
2026-02-10 23:21:12 -08:00
thecaptain789 8ee538ce73 llama : correct typos 'occured' and 'occurences' (#19414)
Co-authored-by: thecaptain789 <thecaptain789@users.noreply.github.com>
2026-02-11 07:05:31 +01:00
Georgi Gerganov 6d95707827 model : fix wavtokenizer embedding notions (#19479) 2026-02-11 07:52:20 +02:00
Georgi Gerganov 89181c0b6d ggml : extend bin bcast for permuted src1 (#19484)
* tests : extend bin bcast for permuted src1

* cont : extend bin support

* cont : s0 is always 1

* tests : simplify
2026-02-11 07:52:00 +02:00
Georgi Gerganov ceaa89b786 metal : consolidate unary ops (#19490) 2026-02-11 07:51:12 +02:00
319 changed files with 26012 additions and 14608 deletions
@@ -41,7 +41,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true
+1 -1
View File
@@ -42,7 +42,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
multiple: true
validations:
required: true
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Install komac
run: |
cargo binstall komac@2.11.2 -y
cargo binstall komac@2.15.0 -y
- name: Find latest release
id: find_latest_release
+1 -1
View File
@@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
- Explicitly informing them that AI-generated pull requests are not accepted by the project
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Providing useful links and pointers found throughout the codebase
Examples of valid questions:
+10 -12
View File
@@ -112,15 +112,9 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
# 3rd party libs
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
# deprecated
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
if (LLAMA_CURL)
message(WARNING "LLAMA_CURL option is deprecated and will be ignored")
endif()
# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
@@ -148,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS)
endif()
# transition helpers
function (llama_option_depr TYPE OLD NEW)
function (llama_option_depr TYPE OLD)
if (${OLD})
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
set(${NEW} ON PARENT_SCOPE)
set(NEW "${ARGV2}")
if(NEW)
message(${TYPE} "${OLD} is deprecated, use ${NEW} instead")
set(${NEW} ON PARENT_SCOPE)
else()
message(${TYPE} "${OLD} is deprecated and will be ignored")
endif()
endif()
endfunction()
@@ -164,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
llama_option_depr(WARNING LLAMA_CURL)
include("cmake/license.cmake")
license_add_file("llama.cpp" "LICENSE")
@@ -197,9 +197,7 @@ add_subdirectory(src)
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
if (LLAMA_HTTPLIB)
add_subdirectory(vendor/cpp-httplib)
endif()
add_subdirectory(vendor/cpp-httplib)
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+1 -1
View File
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
1. Explicitly disclose the manner in which AI was employed.
2. Perform a comprehensive manual review prior to submitting the pull request.
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
+1 -1
View File
@@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
> [!IMPORTANT]
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
## Requirements
+14 -18
View File
@@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=(
-DGGML_OPENMP=${GGML_OPENMP}
)
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
check_required_tool() {
local tool=$1
local install_message=$2
@@ -60,9 +55,12 @@ check_required_tool() {
}
echo "Checking for required tools..."
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
set -e
@@ -260,7 +258,7 @@ combine_static_libraries() {
# Since we have multiple architectures libtool will find object files that do not
# match the target architecture. We suppress these warnings.
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
# Determine SDK, architectures, and install_name based on platform and simulator flag.
local sdk=""
@@ -333,7 +331,7 @@ combine_static_libraries() {
# Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then
if command -v xcrun vtool &>/dev/null; then
if xcrun -f vtool &>/dev/null; then
case "$platform" in
"ios")
echo "Marking binary as a framework binary for iOS..."
@@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos --config Release -- -quiet
@@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos-sim --config Release -- -quiet
@@ -528,13 +524,13 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
# Create XCFramework with correct debug symbols paths
echo "Creating XCFramework..."
xcodebuild -create-xcframework \
xcrun xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-sim/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
-framework $(pwd)/build-ios-device/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-macos/framework/llama.framework \
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
+11 -27
View File
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
llama_add_compile_flags()
# Build info header
#
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
@@ -110,33 +109,16 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
set(LLAMA_COMMON_EXTRA_LIBS build_info)
if (LLAMA_HTTPLIB)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
endif()
target_link_libraries(${TARGET} PRIVATE
build_info
cpp-httplib
)
if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
# Set the correct library file extension based on platform
if (WIN32)
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
# Add Windows-specific libraries
set(LLGUIDANCE_PLATFORM_LIBS
ws2_32 # Windows Sockets API
userenv # For GetUserProfileDirectoryW
ntdll # For NT functions
bcrypt # For BCryptGenRandom
)
else()
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
set(LLGUIDANCE_PLATFORM_LIBS "")
endif()
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
@@ -158,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
add_dependencies(llguidance llguidance_ext)
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
# Add platform libraries to the main target
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
endif ()
target_link_libraries(${TARGET} PRIVATE llguidance)
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
endif()
endif()
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
+1 -1
View File
@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
+32 -135
View File
@@ -1,7 +1,3 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "gguf.h"
@@ -9,12 +5,12 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "unicode.h"
#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <chrono>
#include <cstdarg>
#include <cstring>
@@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
bool has_suffix = string_ends_with(str, suffix);
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$&");
@@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
return false;
}
std::u32string filename_utf32;
try {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
size_t offset = 0;
while (offset < filename.size()) {
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
filename_utf32 = converter.from_bytes(filename);
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
// or invalid encodings were encountered. Reject such attempts
std::string filename_reencoded = converter.to_bytes(filename_utf32);
if (filename_reencoded != filename) {
if (result.status != utf8_parse_result::SUCCESS) {
return false;
}
} catch (const std::exception &) {
return false;
}
uint32_t c = result.codepoint;
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
for (char32_t c : filename_utf32) {
if ((result.bytes_consumed == 2 && c < 0x80) ||
(result.bytes_consumed == 3 && c < 0x800) ||
(result.bytes_consumed == 4 && c < 0x10000)) {
return false;
}
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
if (c <= 0x1F // Control characters (C0)
|| c == 0x7F // Control characters (DEL)
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
@@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|| c == 0x2215 // Division Slash (forward slash equivalent)
|| c == 0x2216 // Set Minus (backslash equivalent)
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c > 0x10FFFF // Max Unicode limit
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == ':' || c == '*' // Illegal characters
@@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
// Subdirectories not allowed, reject path separators
return false;
}
offset += result.bytes_consumed;
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -898,7 +851,8 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
defined(__OpenBSD__) || defined(__NetBSD__)
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else if (std::getenv("HOME")) {
@@ -1242,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
return res;
}
int err = llama_apply_adapter_cvec(
int err = llama_set_adapter_cvec(
lctx,
cvec.data.data(),
cvec.data.size(),
@@ -1344,12 +1298,15 @@ std::string get_model_endpoint() {
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_set_adapter_lora(ctx, la.ptr, la.scale);
}
std::vector<llama_adapter_lora *> loras;
std::vector<float> scales;
for (auto & la: lora) {
loras.push_back(la.ptr);
scales.push_back(la.scale);
}
llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
}
struct llama_model_params common_model_params_to_llama(common_params & params) {
@@ -1469,66 +1426,6 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
//
// Vocab utils
//
+41 -26
View File
@@ -670,30 +670,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
}
template<>
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
{
std::vector<std::string> parts;
size_t begin_pos = 0;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
size_t delim_pos = str.find(delim);
while (delim_pos != std::string::npos) {
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
parts.emplace_back(part);
begin_pos = separator_pos + 1;
separator_pos = input.find(separator, begin_pos);
begin_pos = delim_pos + 1;
delim_pos = str.find(delim, begin_pos);
}
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
parts.emplace_back(str.substr(begin_pos));
return parts;
}
static bool string_starts_with(const std::string & str,
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
// remove when moving to c++20
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
return str.size() >= prefix.size() &&
str.compare(0, prefix.size(), prefix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
// remove when moving to c++20
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
return str.size() >= suffix.size() &&
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
if (string_ends_with(str, suffix)) {
str.resize(str.size() - suffix.size());
return true;
}
return false;
}
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
if (!str.empty() && !stop.empty()) {
const size_t max_len = std::min(str.size(), stop.size());
const char last_char = str.back();
for (size_t len = max_len; len > 0; --len) {
if (stop[len - 1] == last_char) {
if (string_ends_with(str, stop.substr(0, len))) {
return str.size() - len;
}
}
}
}
return std::string::npos;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -779,16 +804,6 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longet common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//
@@ -880,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
static std::string llm_ffn_exps_block_regex(int idx) {
inline std::string llm_ffn_exps_block_regex(int idx) {
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
}
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
}
+78 -136
View File
@@ -19,9 +19,7 @@
#include <thread>
#include <vector>
#if defined(LLAMA_USE_HTTPLIB)
#include "http.h"
#endif
#ifndef __EMSCRIPTEN__
#ifdef __linux__
@@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
}
static std::string read_etag(const std::string & path) {
std::string none;
const std::string etag_path = path + ".etag";
if (std::filesystem::exists(etag_path)) {
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return none;
}
std::string etag;
std::getline(etag_in, etag);
return etag;
if (!std::filesystem::exists(etag_path)) {
return {};
}
// no etag file, but maybe there is an old .json
// remove this code later
const std::string metadata_path = path + ".json";
if (std::filesystem::exists(metadata_path)) {
std::ifstream metadata_in(metadata_path);
try {
nlohmann::json metadata_json;
metadata_in >> metadata_json;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata_json.dump().c_str());
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
std::string etag = metadata_json.at("etag");
write_etag(path, etag);
if (!std::filesystem::remove(metadata_path)) {
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
}
return etag;
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return {};
}
return none;
std::string etag;
std::getline(etag_in, etag);
return etag;
}
static bool is_http_status_ok(int status) {
@@ -168,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
return {hf_repo, tag};
}
#if defined(LLAMA_USE_HTTPLIB)
class ProgressBar {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
@@ -305,7 +275,10 @@ static bool common_pull_file(httplib::Client & cli,
);
if (!res) {
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
LOG_ERR("%s: download failed: %s (status: %d)\n",
__func__,
httplib::to_string(res.error()).c_str(),
res ? res->status : -1);
return false;
}
@@ -344,62 +317,64 @@ static int common_download_file_single_online(const std::string & url,
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
for (int i = 0; i < max_attempts; ++i) {
auto head = cli.Head(parts.path);
bool head_ok = head && head->status >= 200 && head->status < 300;
if (!head_ok) {
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head->status; // cannot use cached file, return raw status code
// TODO: maybe retry only on certain codes
}
std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head_ok && head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head_ok && head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
bool should_download_from_scratch = false;
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
last_etag.c_str(), etag.c_str());
should_download_from_scratch = true;
}
auto head = cli.Head(parts.path);
if (!head || head->status < 200 || head->status >= 300) {
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
if (!should_download_from_scratch) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
return head ? head->status : -1;
}
std::string etag;
if (head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
bool supports_ranges = false;
if (head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
if (file_exists) {
if (etag.empty()) {
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (!last_etag.empty() && last_etag == etag) {
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
}
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
}
}
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;
for (int i = 0; i < max_attempts; ++i) {
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
delay *= retry_delay_seconds;
}
const std::string path_temporary = path + ".downloadInProgress";
size_t existing_size = 0;
if (std::filesystem::exists(path_temporary)) {
if (supports_ranges && !should_download_from_scratch) {
if (supports_ranges) {
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
@@ -407,32 +382,23 @@ static int common_download_file_single_online(const std::string & url,
}
}
// start the download
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
} else {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str());
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
continue;
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status;
}
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
}
if (!etag.empty()) {
write_etag(path, etag);
}
return head->status; // TODO: use actual GET status?
}
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}
@@ -798,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
}
}
#else
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
int common_download_file_single(const std::string &,
const std::string &,
const std::string &,
bool,
const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // defined(LLAMA_USE_HTTPLIB)
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
+41 -2
View File
@@ -4,6 +4,7 @@
// for converting from JSON to jinja values
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
#include <cctype>
#include <vector>
@@ -715,8 +716,46 @@ const func_builtins & value_string_t::get_builtins() const {
return args.get_pos(0);
}},
{"tojson", tojson},
{"indent", [](const func_args &) -> value {
throw not_implemented_exception("String indent builtin not implemented");
{"indent", [](const func_args &args) -> value {
args.ensure_count(1, 4);
value val_input = args.get_pos(0);
value val_width = args.get_kwarg_or_pos("width", 1);
const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
if (!is_val<value_string>(val_input)) {
throw raised_exception("indent() first argument must be a string");
}
std::string indent;
if (is_val<value_int>(val_width)) {
indent.assign(val_width->as_int(), ' ');
} else if (is_val<value_string>(val_width)) {
indent = val_width->as_string().str();
} else {
indent = " ";
}
std::string indented;
std::string input = val_input->as_string().str();
std::istringstream iss = std::istringstream(input);
std::string line;
while (std::getline(iss, line)) {
if (!indented.empty()) {
indented.push_back('\n');
}
if ((indented.empty() ? first : (!line.empty() || blank))) {
indented += indent;
}
indented += line;
}
if (!input.empty() && input.back() == '\n') {
indented.push_back('\n');
if (blank) {
indented += indent;
}
}
auto res = mk_val<value_string>(indented);
res->val_str.mark_input_based_on(val_input->as_string());
return res;
}},
{"join", [](const func_args &) -> value {
throw not_implemented_exception("String join builtin not implemented");
+1 -1
View File
@@ -461,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map,
slot_max = v;
}
}
// What is sum of the other occurences?
// What is sum of the other occurrences?
uint32_t sum_occur = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (v == slot_max) {
+2 -2
View File
@@ -44,7 +44,7 @@ llama_tokens common_ngram_simple_draft(
// statistics of a m-gram after a known n-gram
struct common_ngram_map_value {
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
};
@@ -53,7 +53,7 @@ struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
uint16_t key_num; // number of occurences of this key n-gram in token-history
uint16_t key_num; // number of occurrences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
};
+330 -102
View File
@@ -160,8 +160,6 @@ class ModelBase:
self.ftype = gguf.LlamaFileType.MOSTLY_F16
logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -527,6 +525,8 @@ class ModelBase:
return ()
def prepare_tensors(self):
self.dequant_model()
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -570,6 +570,7 @@ class ModelBase:
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.FFN_GATE_INP_SHEXP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
@@ -1048,6 +1049,9 @@ class TextModel(ModelBase):
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
res = "glm4"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
@@ -1081,9 +1085,6 @@ class TextModel(ModelBase):
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -1123,6 +1124,9 @@ class TextModel(ModelBase):
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
res = "command-r"
if chkhsh == "d772b220ace2baec124bed8cfafce0ead7d6c38a4b65ef11261cf9d5d62246d1":
# ref: https://huggingface.co/CohereLabs/tiny-aya-base
res = "tiny_aya"
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
res = "qwen2"
@@ -1264,6 +1268,9 @@ class TextModel(ModelBase):
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
res = "qwen35"
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
res = "joyai-llm"
if res is None:
logger.warning("\n")
@@ -1608,6 +1615,23 @@ class TextModel(ModelBase):
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_glm(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
@@ -1815,7 +1839,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -1870,7 +1894,15 @@ class MmprojModel(ModelBase):
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
if preprocessor_config_path.is_file():
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
cfg = json.load(f)
# move media_proc_cfg to root level for compat
if "media_proc_cfg" in cfg:
cfg = {
**cfg,
**cfg["media_proc_cfg"],
}
# merge configs
self.preprocessor_config = {**self.preprocessor_config, **cfg}
# prefer processor_config.json if possible
processor_config_path = self.dir_model / "processor_config.json"
@@ -1919,10 +1951,10 @@ class MmprojModel(ModelBase):
self.image_size = self.find_vparam(["image_size"])
self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -2700,8 +2732,6 @@ class AfmoeModel(LlamaModel):
super().set_gguf_parameters()
# MoE parameters
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
self.gguf_writer.add_expert_shared_count(n_shared_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
@@ -2723,7 +2753,7 @@ class AfmoeModel(LlamaModel):
# Handle expert weights - they're already merged in the HF format
# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -4048,6 +4078,87 @@ class InternVisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register(
"NemotronH_Nano_VL_V2",
"RADIOModel",
)
class NemotronNanoV2VLModel(MmprojModel):
# ViT-Huge architecture parameters for RADIO v2.5-h
_vit_hidden_size = 1280
_vit_intermediate_size = 5120
_vit_num_layers = 32
_vit_num_heads = 16
def get_vision_config(self) -> dict[str, Any] | None:
# RADIO config doesn't have standard ViT parameters, so they need to be constructed manually
vision_config = self.global_config.get("vision_config")
if vision_config is None:
return None
# Add ViT-H parameters
vision_config = {
**vision_config,
"hidden_size": self._vit_hidden_size,
"intermediate_size": self._vit_intermediate_size,
"num_hidden_layers": self._vit_num_layers,
"num_attention_heads": self._vit_num_heads,
"image_size": self.global_config.get("force_image_size", 512),
}
return vision_config
def set_gguf_parameters(self):
if "image_mean" not in self.preprocessor_config:
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
if "image_std" not in self.preprocessor_config:
self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225]
super().set_gguf_parameters()
hparams = self.global_config
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL)
self.gguf_writer.add_vision_attention_layernorm_eps(1e-6)
self.gguf_writer.add_vision_use_gelu(True)
downsample_ratio = hparams.get("downsample_ratio", 0.5)
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".position_embd." in new_name or "pos_embed" in new_name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "input_conditioner" in name:
return
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
if "patch_generator.pos_embed" in name:
if not name.endswith(".weight"):
name += ".weight"
# Downsample position embeddings for fixed 512x512 image size
import torch.nn.functional as F
n_embd = self.hparams["hidden_size"]
image_size = self.global_config.get("force_image_size", 512)
patch_size = self.hparams["patch_size"]
target_patches_per_side = image_size // patch_size # 32
max_patches_per_side = int((data_torch.shape[1]) ** 0.5) # 128
if target_patches_per_side != max_patches_per_side:
# Reshape to grid, interpolate, flatten back
data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd)
data_torch = data_torch.permute(0, 3, 1, 2).float() # [1, n_embd, 128, 128]
data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side),
mode='bilinear', align_corners=True)
data_torch = data_torch.permute(0, 2, 3, 1) # [1, 32, 32, n_embd]
data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd)
# Reshape linear patch embedding to conv2d format for ggml_conv_2d
# From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size]
if "patch_generator.embedder" in name:
patch_size = self.hparams["patch_size"]
n_embd = self.hparams["hidden_size"]
data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size)
if name.startswith("vision_model.radio_model.model.") or name.startswith("mlp1."):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("WavTokenizerDec")
class WavTokenizerDecModel(TextModel):
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
@@ -4090,8 +4201,6 @@ class Qwen2MoeModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -4136,7 +4245,7 @@ class Qwen2MoeModel(TextModel):
return
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -4475,7 +4584,7 @@ class Qwen3VLVisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
class Glm4VVisionModel(Qwen3VLVisionModel):
def set_gguf_parameters(self):
MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters
@@ -4887,13 +4996,13 @@ class PhiMoeModel(Phi3MiniModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams["num_local_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -5305,7 +5414,7 @@ class KimiLinearModel(TextModel):
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False)
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -5900,12 +6009,13 @@ class NomicBertModel(BertModel):
if "mlp.experts.bias" in name:
return # Explicitly return.
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
if "mlp.experts.mlp.w1" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
name += ".weight"
if "mlp.experts.mlp.w2" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.transpose(1, 2)
name += ".weight"
@@ -5915,7 +6025,6 @@ class NomicBertModel(BertModel):
super().set_gguf_parameters()
if self.is_moe:
self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
def _is_tokenizer_xlmroberta(self) -> bool:
@@ -7029,6 +7138,8 @@ class Mamba2Model(TextModel):
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
if "llm_config" in hparams:
hparams["text_config"] = hparams["llm_config"]
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
@@ -7150,8 +7261,8 @@ class JambaModel(TextModel):
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
self.gguf_writer.add_file_type(self.ftype)
_experts: list[dict[str, Tensor]] | None = None
@@ -7169,7 +7280,7 @@ class JambaModel(TextModel):
# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
@@ -7255,6 +7366,17 @@ class Cohere2Model(TextModel):
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Cohere2 runtime in llama.cpp expects no bias tensors;
# the actual weight only contains 0-value tensors as bias, we can skip them
if name.endswith(".bias"):
if torch.any(data_torch != 0):
raise ValueError(f"Bias tensor {name!r} is not zero.")
logger.debug(f"Skipping bias tensor {name!r} for Cohere2 conversion.")
return
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("OlmoForCausalLM")
@ModelBase.register("OLMoForCausalLM")
@@ -7317,8 +7439,6 @@ class OlmoeModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
_experts: list[dict[str, Tensor]] | None = None
@@ -7326,7 +7446,7 @@ class OlmoeModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -7695,12 +7815,16 @@ class DeepseekModel(TextModel):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# TODO @ngxson : remove this when we support MTP for deepseek models
skip_mtp = True
def set_vocab(self):
try:
self._set_vocab_gpt2()
@@ -7813,8 +7937,8 @@ class DeepseekV2Model(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
# skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
return
if name.startswith("siglip2.") or name.startswith("merger."):
return
@@ -7832,10 +7956,11 @@ class DeepseekV2Model(TextModel):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# skip Multi-Token Prediction (MTP) layers
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return
if self.skip_mtp:
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model.layers.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return
# process the experts separately
if name.find("mlp.experts") != -1:
@@ -7902,10 +8027,6 @@ class MiniMaxM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MINIMAXM2
_experts_cache: dict[int, dict[str, Tensor]] = {}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["num_experts"] = self.hparams["num_local_experts"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -7918,7 +8039,7 @@ class MiniMaxM2Model(TextModel):
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
@@ -8655,7 +8776,7 @@ class Glm4Model(TextModel):
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams["num_key_value_heads"]
n_embd = self.hparams["hidden_size"]
head_dim = n_embd // n_head
head_dim = self.hparams.get("head_dim", n_embd // n_head)
# because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
@@ -8664,6 +8785,27 @@ class Glm4Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GlmOcrForConditionalGeneration")
class GlmOCRModel(Glm4Model):
model_arch = gguf.MODEL_ARCH.GLM4
use_mrope = False
partial_rotary_factor = 0.5
# Note: GLM-OCR is the same as GLM4, but with an extra NextN/MTP prediction layer
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# GLM-OCR has num_hidden_layers + 1 actual layers (including NextN layer)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_gguf_parameters(self):
super().set_gguf_parameters()
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
@@ -8675,24 +8817,7 @@ class Glm4MoeModel(TextModel):
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -8792,26 +8917,38 @@ class Glm4MoeModel(TextModel):
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# copied from Glm4MoeModel
def set_vocab(self):
from transformers import AutoTokenizer
return self._set_vocab_glm()
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
@ModelBase.register("GlmMoeDsaForCausalLM")
class GlmMoeDsaModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.GLM_DSA
skip_mtp = False
special_vocab.add_to_gguf(self.gguf_writer)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
rope_dim = self.hparams["qk_rope_head_dim"]
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
# NextN/MTP prediction layers
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
# DSA indexer parameters
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
@@ -9128,7 +9265,6 @@ class ExaoneMoEModel(Exaone4Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
moe_intermediate_size = self.hparams["moe_intermediate_size"]
num_shared_experts = self.hparams["num_shared_experts"]
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -9169,7 +9305,7 @@ class ExaoneMoEModel(Exaone4Model):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -9320,7 +9456,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
# case, the model architecture needs to be updated to a standard
# "granite" or "granitemoe" model
if not self._ssm_layers:
has_experts = self.find_hparam(["num_experts_per_tok"], optional=True)
has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)
new_arch = (
gguf.MODEL_ARCH.GRANITE_MOE
if has_experts else
@@ -9516,6 +9652,14 @@ class NemotronHModel(GraniteHybridModel):
self.gguf_writer.add_add_bos_token(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision model and projector tensors for VLM models (handled by mmproj) (e.g., Nemotron Nano 12B v2 VL)
if name.startswith(("vision_model.", "mlp1.")):
return
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
if name.startswith("language_model."):
name = name[len("language_model."):]
if self.is_moe and bid is not None:
if name.endswith("mixer.gate.e_score_correction_bias"):
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -9610,7 +9754,6 @@ class BailingMoeModel(TextModel):
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_weights_scale(1.0)
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
@@ -9644,7 +9787,7 @@ class BailingMoeModel(TextModel):
yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid)
return
elif name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -9715,7 +9858,6 @@ class BailingMoeV2Model(TextModel):
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
@@ -9726,7 +9868,7 @@ class BailingMoeV2Model(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "mlp.experts" in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -9772,8 +9914,6 @@ class GroveMoeModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -9794,7 +9934,7 @@ class GroveMoeModel(TextModel):
# process the experts separately
if name.find("chunk_experts") != -1:
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2 # see add_experts_per_group
assert bid is not None
if self._chunk_experts is None:
@@ -9821,7 +9961,7 @@ class GroveMoeModel(TextModel):
else:
return
elif name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10214,7 +10354,6 @@ class HunYuanMoEModel(TextModel):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
moe_intermediate_size = hparams["moe_intermediate_size"]
@@ -10257,7 +10396,7 @@ class HunYuanMoEModel(TextModel):
return
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10299,16 +10438,9 @@ class LLaDAMoEModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
# number of experts used per token (top-k)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
self.gguf_writer.add_mask_token_id(156895)
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_diffusion_shift_logits(False)
@@ -10319,7 +10451,7 @@ class LLaDAMoEModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -10656,7 +10788,6 @@ class LFM2MoeModel(TextModel):
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
@@ -10677,7 +10808,7 @@ class LFM2MoeModel(TextModel):
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
@@ -10787,9 +10918,9 @@ class SmallThinkerModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -10814,7 +10945,7 @@ class SmallThinkerModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
n_experts = self.hparams.get("moe_num_primary_experts") or self.find_hparam(["num_local_experts", "num_experts"])
assert bid is not None
if self._experts is None:
@@ -11176,6 +11307,103 @@ class KimiVLModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiK25ForConditionalGeneration")
class KimiK25Model(MmprojModel):
"""Kimi-K2.5 with MoonViT3d vision encoder"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
self.patch_size = self.hparams_vision.get("patch_size", 14)
# Set image_size for compatibility with base class
# Use position embedding dimensions as image_size reference
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
def set_gguf_parameters(self):
# Base class MmprojModel.set_gguf_parameters() already writes:
# - vision_block_count, vision_head_count, vision_embedding_length
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
# Position embedding parameters (for interpolation)
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
# Projector parameters
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
# Image size limits
# Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet)
in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384)
min_patches = 8 # reasonable minimum
pixels_per_patch = self.patch_size ** 2
self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch)
self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch)
@staticmethod
def permute(weights: Tensor, n_head: int) -> Tensor:
out_dim, in_dim = weights.shape
head_dim = out_dim // n_head
w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
w = w.permute(0, 2, 1, 3, 4)
return w.reshape(out_dim, in_dim)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision and projector tensors
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
if not is_vision:
return
assert self.hparams_vision is not None
n_head = self.hparams_vision.get("num_attention_heads", 16)
# Permute Q/K weights/biases from interleaved to split RoPE format
# This allows using build_rope_2d at runtime without post-permutation.
if "wqkv" in name:
out_dim = data_torch.shape[0]
qkv_dim = out_dim // 3
head_dim = qkv_dim // n_head
if "weight" in name:
wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :]
wq = self.permute(wq, n_head)
wk = self.permute(wk, n_head)
data_torch = torch.cat([wq, wk, wv], dim=0)
elif "bias" in name:
bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:]
bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
data_torch = torch.cat([bq, bk, bv], dim=0)
# Temporal embeddings: (T, 1, C) → (T, C)
if "pos_emb.time_weight" in name:
T, _, C = data_torch.shape
data_torch = data_torch.reshape(T, C)
# PatchMergerMLP tensor name mapping
# proj.0.weight → proj.linear_1.weight
# proj.2.weight → proj.linear_2.weight
if "mm_projector.proj.0." in name:
name = name.replace(".proj.0.", ".proj.linear_1.")
elif "mm_projector.proj.2." in name:
name = name.replace(".proj.2.", ".proj.linear_2.")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):
+4 -2
View File
@@ -99,6 +99,7 @@ models = [
{"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
{"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", },
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
@@ -148,7 +149,8 @@ models = [
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@@ -158,6 +160,7 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
@@ -171,7 +174,6 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]
+1 -1
View File
@@ -246,7 +246,7 @@ cmake --build build --config release
1. **Retrieve and prepare model**
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
**Notes**:
+2 -2
View File
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
#### Retrieve and prepare model
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
+1 -1
View File
@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
Preset CMake variables:
+3 -3
View File
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ✅ | ❓ |
| BF16 | 🚫 | ✅ | ❓ |
| BF16 | | ✅ | ❓ |
| Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | ❓ | ❓ |
| MXFP4 | | ❓ | ❓ |
| Q5_0 | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ❓ | ❓ |
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
config = config.text_config
multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
def print_if_exists(label, obj, attr, default="N/A"):
val = getattr(obj, attr) if hasattr(obj, attr) else default
print(f"{label}", val)
print_if_exists("Vocab size: ", config, "vocab_size")
print_if_exists("Hidden size: ", config, "hidden_size")
print_if_exists("Number of layers: ", config, "num_hidden_layers")
print_if_exists("BOS token id: ", config, "bos_token_id")
print_if_exists("EOS token id: ", config, "eos_token_id")
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name:
@@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False):
print(tensor_name)
def print_tensor_info(model_path: Path, tensor_name: str):
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
@@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str):
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
if num_values is not None:
tensor = f.get_tensor(tensor_name)
print(f"Dtype: {tensor.dtype}")
flat = tensor.flatten()
n = min(num_values, flat.numel())
print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
@@ -127,6 +133,15 @@ def main():
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
parser.add_argument(
"-n", "--num-values",
nargs="?",
const=10,
default=None,
type=int,
metavar="N",
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
)
args = parser.parse_args()
@@ -152,7 +167,7 @@ def main():
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name)
print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":
+1 -1
View File
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_PATCH 7)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
+1
View File
@@ -752,6 +752,7 @@ extern "C" {
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_view (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
+4 -9
View File
@@ -17,11 +17,6 @@
//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
#define AT_PRINTF(...)
static bool ggml_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
// ops that return true for this function must not use restrict pointers for their backend implementations
bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
GGML_ASSERT(buffer_id >= 0);
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
hn->allocated = true;
assert(hn->addr.offset == 0);
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
if (p_hn->n_children == 1 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
if (ggml_impl_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
// GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
// control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
// itself is never used and should not be considered a dependency
if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
struct ggml_tensor * view_src = node->view_src;
ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
}
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
if (ggml_impl_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
view_src_hn->n_views -= 1;
+9 -7
View File
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# Disable LTO for the feature detection code to prevent cross-module optimization
# from inlining architecture-specific instructions into the score function.
# Without this, LTO can cause SIGILL when loading backends on older CPUs
# (e.g., loading power10 backend on power9 crashes before feature check runs).
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
endfunction()
@@ -569,27 +574,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
cmake_policy(SET CMP0135 NEW)
endif()
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
FetchContent_Declare(KleidiAI_Download
URL ${KLEIDIAI_DOWNLOAD_URL}
DOWNLOAD_EXTRACT_TIMESTAMP NEW
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
FetchContent_MakeAvailable(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download
SOURCE_DIR KLEIDIAI_SRC
POPULATED KLEIDIAI_POPULATED)
if (NOT KLEIDIAI_POPULATED)
message(FATAL_ERROR "KleidiAI source downloaded failed.")
FetchContent_Populate(KleidiAI_Download)
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
endif()
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
# Remove kleidiai target after fetching it
if (TARGET kleidiai)
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
endif()
list(APPEND GGML_CPU_SOURCES
ggml-cpu/kleidiai/kleidiai.cpp
ggml-cpu/kleidiai/kernels.cpp
+310
View File
@@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntb() * 8 == 256) {
constexpr int q8_k_blocklen = 4;
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
// 8 accumulators: 2 row pairs × 4 col pairs
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
svbool_t pg = svptrue_pat_b32(SV_VL8);
svuint32_t idx = svld1(pg, idx_arr);
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
acc_f32_01 = svdup_n_f32(0);
acc_f32_23 = svdup_n_f32(0);
acc_f32_45 = svdup_n_f32(0);
acc_f32_67 = svdup_n_f32(0);
for (int b = 0; b < nb; b++) {
// bsums pairs belongs to the same q8_k subblock
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
const int16x8_t bsums[4]{
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int32_t bsums_arr32[4][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
int16x8_t v16 = bsums[q8_row];
// low 4
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
// high 4
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
}
svint32_t sb_acc_0 = svdup_n_s32(0);
svint32_t sb_acc_2 = svdup_n_s32(0);
svint32_t acc_00 = svdup_n_s32(0);
svint32_t acc_11 = svdup_n_s32(0);
svint32_t acc_22 = svdup_n_s32(0);
svint32_t acc_33 = svdup_n_s32(0);
svint32_t acc_44 = svdup_n_s32(0);
svint32_t acc_55 = svdup_n_s32(0);
svint32_t acc_66 = svdup_n_s32(0);
svint32_t acc_77 = svdup_n_s32(0);
svint32_t bias_acc_00 = svdup_n_s32(0);
svint32_t bias_acc_22 = svdup_n_s32(0);
svint32_t bias_acc_44 = svdup_n_s32(0);
svint32_t bias_acc_66 = svdup_n_s32(0);
for (int sb = 0; sb < QK_K / 64; sb++) {
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
svint32_t q4sb_mins_0, q4sb_mins_1;
{
// 2-superblock I am working on
const int offset = sb * 24 + 0 * 12;
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
const int offset1 = sb * 24 + 12;
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
constexpr uint8_t scales_size = 12;
uint32_t sm[3];
memcpy(sm, scales_in, scales_size);
uint32_t sm1[3];
memcpy(sm1, scales_in1, scales_size);
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
/* reinterpret u32 → u8 */
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
/* widen u8 → u16->u32 (lower half only) */
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
uint32_t scales_u32_0 = sm[0] & kmask1;
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
uint32_t scales_u32_2 = sm1[0] & kmask1;
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
svuint32_t S01 = svdup_n_u32(scales_u32_0);
svuint32_t S23 = svdup_n_u32(scales_u32_1);
svuint32_t R01 = svdup_n_u32(scales_u32_2);
svuint32_t R23 = svdup_n_u32(scales_u32_3);
svint8_t S01_b = svreinterpret_s8_u32(S01);
svint8_t S23_b = svreinterpret_s8_u32(S23);
svint8_t R01_b = svreinterpret_s8_u32(R01);
svint8_t R23_b = svreinterpret_s8_u32(R23);
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
}
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
// Load 32-byte per row pair, 1 subblock each time
// predicate for activating higher lanes for 16 int8 elements
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
// predicate for activating lower lanes for 16 int8 elements
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
sb_acc_0 = svdup_n_s32(0);
sb_acc_2 = svdup_n_s32(0);
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
if(cp == 0) {
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
}
if(cp == 1) {
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
}
if(cp == 2) {
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
}
if(cp == 3) {
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
}
}
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
} // for sb
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
// Broadcast q8 scalar
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
q8_d = svdup_f32(q8_ptr[b].d[1]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
q8_d = svdup_f32(q8_ptr[b].d[2]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
q8_d = svdup_f32(q8_ptr[b].d[3]);
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
} // for b
// With the previous reorder, the tile is already in the correct memory layout.
// Predicate for exactly 4 lanes
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
if (i == 0 && j == 0) {
// acc_f32_0 → lower half of acc_f32_01
svst1_f32(pg4, s + offset, acc_f32_01);
} else if (i == 0 && j == 1) {
// acc_f32_1 → upper half of acc_f32_01
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
} else if (i == 1 && j == 0) {
// acc_f32_2
svst1_f32(pg4, s + offset, acc_f32_23);
} else if (i == 1 && j == 1) {
// acc_f32_3
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
} else if (i == 2 && j == 0) {
// acc_f32_4
svst1_f32(pg4, s + offset, acc_f32_45);
} else if (i == 2 && j == 1) {
// acc_f32_5
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
} else if (i == 3 && j == 0) {
// acc_f32_6
svst1_f32(pg4, s + offset, acc_f32_67);
} else if (i == 3 && j == 1) {
// acc_f32_7
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
}
}
}
} // for x
} // for y
return;
}
#endif // SVE compile-time end
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
+2 -6
View File
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
GGML_ASSERT(nb00 == sizeof(src0_t));
const auto [ir0, ir1] = get_thread_range(params, src0);
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
GGML_ASSERT(ggml_are_same_shape(src0, src1));
}
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
#ifdef GGML_USE_ACCELERATE
vDSP_fn_t vDSP_op = nullptr;
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
if (is_src1_contiguous) {
if (is_src1_contiguous_rows) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t nr0 = ne00 / ne10;
+2 -2
View File
@@ -6,8 +6,8 @@
#include "ggml-impl.h"
#include "simd-mappings.h"
#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#define GGML_FA_TILE_Q 64
#define GGML_FA_TILE_KV 64
#ifdef __cplusplus
+12 -4
View File
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
const int64_t DV = node->src[2]->ne[0];
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/*.use_ref =*/ cplan->use_ref,
};
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
#else
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
#ifdef GGML_USE_OPENMP
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
#else
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
#endif
ggml_barrier(state->threadpool);
+173 -94
View File
@@ -3,6 +3,7 @@
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "simd-gemm.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"
@@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@@ -8325,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
@@ -8360,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
@@ -8388,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
float * V32 = VKQ32 + Q_TILE_SZ * DV;
float * K_f32 = V32 + KV_TILE_SZ * DV;
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8412,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
{
float * Q_f32 = (float *)Q_q;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
}
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
}
}
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
// skip the tile entirely if all the masks are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
for (int tk = 0; tk < kv_tile; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
// Pad remaining mask entries with -inf
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
if (can_skip) {
@@ -8441,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
for (int tk = 0; tk < kv_tile; tk++) {
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
if (kv_type == GGML_TYPE_F16) {
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
}
} else {
const float * k_f32_src = (const float *)k_data;
for (int64_t dk = 0; dk < DK; dk++) {
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
}
}
}
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
// Set padded KQ entries to -inf so softmax gives them zero weight
if (kv_tile < KV_TILE_SZ) {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
}
}
}
@@ -8487,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}
// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
// V accumulation: VKQ32 += softmax(KQ) * V
// Pack V tile to contiguous F32, zero-padded
for (int tk = 0; tk < kv_tile; tk++) {
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
if (kv_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
} else {
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) {
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
}
}
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
}
// sinks (apply only to valid rows in the tile)
@@ -8730,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool use_tiled = !use_ref &&
bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ);
#ifdef GGML_SIMD
use_tiled &= (DV % GGML_F32_EPR == 0);
#endif
int current_chunk = ith;
while (current_chunk < nchunk) {
+3 -2
View File
@@ -1916,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
// buffer large enough for the max interleave block size (8 bytes)
uint64_t elems;
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
}
// The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
+136
View File
@@ -0,0 +1,136 @@
#pragma once
// Computes C[M x N] += A[M x K] * B[K x N]
#include "simd-mappings.h"
// TODO: add support for sizeless vector types
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
// TODO: untested on avx512
// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
#elif defined(__AVX2__) || defined(__AVX__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
#else
static constexpr int GEMM_RM = 2;
static constexpr int GEMM_RN = 2;
#endif
template <int RM, int RN>
static inline void simd_gemm_ukernel(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
GGML_F32_VEC acc[RM][RN];
for (int64_t i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
}
}
for (int64_t kk = 0; kk < K; kk++) {
GGML_F32_VEC Bv[RN];
for (int r = 0; r < RN; r++) {
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
}
for (int64_t i = 0; i < RM; i++) {
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
for (int r = 0; r < RN; r++) {
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
}
}
}
for (int64_t i = 0; i < RM; i++) {
for (int r = 0; r < RN; r++) {
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
}
}
}
// C[M x N] += A[M x K] * B[K x N]
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
static constexpr int KN = GGML_F32_EPR;
int64_t ii = 0;
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
}
for (; jj < N; jj++) {
for (int64_t i = 0; i < GEMM_RM; i++) {
float a = C[i * N + jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[i + kk] * B[kk * N + jj];
}
C[i * N + jj] = a;
}
}
A += GEMM_RM * K;
C += GEMM_RM * N;
}
// Tail rows: one at a time
for (; ii < M; ii++) {
int64_t jj = 0;
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
}
for (; jj < N; jj++) {
float a = C[jj];
for (int64_t kk = 0; kk < K; kk++) {
a += A[kk] * B[kk * N + jj];
}
C[jj] = a;
}
A += K;
C += N;
}
}
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
#else // scalar path
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
for (int64_t i = 0; i < M; i++) {
for (int64_t j = 0; j < N; j++) {
float sum = C[i * N + j];
for (int64_t kk = 0; kk < K; kk++) {
sum += A[i * K + kk] * B[kk * N + j];
}
C[i * N + j] = sum;
}
}
}
#endif // GGML_SIMD
+26
View File
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
float32x4_t tmp = x[0] + vec_reve(x[0]); \
res = tmp[0] + tmp[1]; \
}
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
{ \
float32x4_t v = vec_add(vec_add(s0, s1), \
vec_add(s2, s3)); \
v = vec_add(v, vec_sld(v, v, 8)); \
v = vec_add(v, vec_sld(v, v, 4)); \
res += (ggml_float)vec_extract(v, 0); \
}
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
// BF16 s390x
#define GGML_BF16_STEP 16
#define GGML_BF16_EPR 8
#define GGML_BF16x8 __vector unsigned short
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
#define GGML_BF16_VEC GGML_BF16x8
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
#define GGML_BF16_FMA_LO(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
#define GGML_BF16_FMA_HI(acc, x, y) \
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
#elif defined(__riscv_v_intrinsic)
// compatible with vlen >= 128
+1 -1
View File
@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS
+1 -2
View File
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
#endif
#if defined(__POWER9_VECTOR__)
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
const int np = (n & ~(GGML_BF16_STEP - 1));
if (np > 0) {
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
+32 -30
View File
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
//size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
}
}
+38 -24
View File
@@ -7,7 +7,8 @@
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
template <bool need_check>
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <typename src_t, typename dst_t>
+26 -5
View File
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
#else
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
#else
const half * K_h_f16 = K_h;
const half * V_h_f16 = V_h;
half * KQ_f16 = KQ;
half * VKQ_f16 = VKQ;
#endif
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
}
}
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
KQ_f16 + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, wmma::mem_col_major);
}
+21 -35
View File
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
if (ne2 <= 4) {
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
cudaStream_t stream = ctx.stream();
GGML_ASSERT(nb12 % nb11 == 0);
@@ -2865,14 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2887,30 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
@@ -3640,11 +3619,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_add_node;
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;
continue;
@@ -4542,6 +4523,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
// TODO: should become:
//return ggml_is_contiguous_rows(op->src[0]);
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -4820,8 +4803,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
// TODO: extend support like so:
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:
+21 -16
View File
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR2_XXS; ++l) {
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
const int ls = aux32 >> 28;
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
const float d = bxi->d;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#else
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR2_XS; ++l) {
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#pragma unroll
for (int l = 0; l < QR3_XXS; ++l) {
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+1
View File
@@ -1,6 +1,7 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
+31 -17
View File
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
#endif
}
static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
// v is a 7 bit int, with the 8th sign being encodable as popcnt
// with xor we can "correct" the bit instead of having to mask
const uint32_t p = __popc(v) & 1;
const uint32_t s = v ^ p << 7;
// broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
return s * 0x01010101;
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
int sumi = 0;
#pragma unroll
for (int k0 = 0; k0 < 8; k0 += 2) {
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
}
const int ls = aux32 >> 28;
sumi = (ls*sumi + sumi/2)/4;
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
return d * sumi;
}
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
int sumi1 = 0;
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+106 -7
View File
@@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
@@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
return true;
}
static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
return true;
}
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * dst = op; // indices
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (dst->type != GGML_TYPE_I32) {
return false;
}
if (src0->ne[0] > (16*1024)) {
// reject tensors with huge rows for now
return false;
}
return true;
}
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];
@@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
case GGML_OP_SUB:
req->op = HTP_OP_SUB;
break;
case GGML_OP_DIV:
req->op = HTP_OP_DIV;
break;
default:
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
break;
@@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
return n_bufs;
}
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_ARGSORT;
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
@@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
supported = true;
break;
case GGML_OP_SQR:
req->op = HTP_OP_SQR;
supported = true;
break;
case GGML_OP_SQRT:
req->op = HTP_OP_SQRT;
supported = true;
break;
case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU;
@@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
req->op = HTP_OP_GLU_SWIGLU_OAI;
supported = true;
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
req->op = HTP_OP_GLU_GEGLU;
supported = true;
}
break;
@@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
return n_bufs;
}
static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_SUM_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_ROPE;
@@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
break;
case GGML_OP_ADD_ID:
@@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_SUM_ROWS:
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
break;
case GGML_OP_UNARY:
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
@@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
break;
case GGML_OP_GLU:
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
}
break;
@@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break;
case GGML_OP_ARGSORT:
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
@@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
supp = ggml_hexagon_supported_binary(sess, op);
break;
@@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_SUM_ROWS:
supp = ggml_hexagon_supported_sum_rows(sess, op);
break;
case GGML_OP_SOFT_MAX:
supp = ggml_hexagon_supported_softmax(sess, op);
break;
@@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_GLU:
{
const auto glu_op = ggml_get_glu_op(op);
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
supp = ggml_hexagon_supported_activations(sess, op);
}
break;
@@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_cpy(sess, op);
break;
case GGML_OP_ARGSORT:
supp = ggml_hexagon_supported_argsort(sess, op);
break;
default:
break;
}
+3
View File
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include_directories(
${HEXAGON_SDK_ROOT}/incs
${HEXAGON_SDK_ROOT}/incs/stddef
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}
@@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
softmax-ops.c
act-ops.c
rope-ops.c
@@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
argsort-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
+150 -2
View File
@@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
// geglu tanh implementation
// geglu(x, g) = gelu(x) * g
// gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
}
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static int execute_op_activations_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
act_op_func = unary_gelu_f32;
op_type = "gelu-f32";
break;
case HTP_OP_GLU_GEGLU:
act_op_func = glu_geglu_f32;
op_type = "geglu-f32";
break;
default:
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;
+281
View File
@@ -0,0 +1,281 @@
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <HAP_farf.h>
#include <HAP_perf.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"
#include "hvx-utils.h"
#include "hex-dma.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
struct htp_argsort_context {
struct htp_ops_context * octx;
uint32_t nrows_per_thread;
};
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
{
const HVX_Vector one = Q6_V_vsplat_R(1);
const HVX_Vector zero = Q6_V_vzero();
HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
return hvx_vec_get_i32(sum) == 32;
}
// Sorts values and mirrors swaps to indices.
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
if (left >= right) return;
int pivot_idx = (left + right) / 2;
float pivot = values[pivot_idx];
int i = left;
int j = right;
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
while (i <= j) {
// Vectorized scan for i
while (i <= j) {
// Check if we have at least one full vector
if (i + 32 <= j) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
if (all_greater_f32(pivot_vec, vals_vec)) {
// If all elements are < pivot, we can skip this whole block
i += 32;
continue;
}
}
// Scalar fallback / cleanup
if (values[i] < pivot) {
i++;
} else {
break;
}
}
// Vectorized scan for j
while (i <= j) {
if (j - 32 >= i) {
// Load 32 elements ending at j.
// Since we want `values[j] > pivot`, let's load from j-31 to j.
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
if (all_greater_f32(vals_vec, pivot_vec)) {
j -= 32;
continue;
}
}
if (values[j] > pivot) {
j--;
} else {
break;
}
}
if (i <= j) {
float tmp_val = values[i];
values[i] = values[j];
values[j] = tmp_val;
int32_t tmp_idx = indices[i];
indices[i] = indices[j];
indices[j] = tmp_idx;
i++;
j--;
}
}
if (left < j) quicksort_values_indices_asc(values, indices, left, j);
if (i < right) quicksort_values_indices_asc(values, indices, i, right);
}
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
if (left >= right) return;
int pivot_idx = (left + right) / 2;
float pivot = values[pivot_idx];
int i = left;
int j = right;
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
while (i <= j) {
// Vectorized scan for i (values[i] > pivot)
while (i <= j) {
if (i + 32 <= j) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
if (all_greater_f32(vals_vec, pivot_vec)) {
i += 32;
continue;
}
}
if (values[i] > pivot) {
i++;
} else {
break;
}
}
// Vectorized scan for j (values[j] < pivot)
while (i <= j) {
if (j - 32 >= i) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
if (all_greater_f32(pivot_vec, vals_vec)) {
j -= 32;
continue;
}
}
if (values[j] < pivot) {
j--;
} else {
break;
}
}
if (i <= j) {
float tmp_val = values[i];
values[i] = values[j];
values[j] = tmp_val;
int32_t tmp_idx = indices[i];
indices[i] = indices[j];
indices[j] = tmp_idx;
i++;
j--;
}
}
if (left < j) quicksort_values_indices_desc(values, indices, left, j);
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
struct htp_ops_context * octx = actx->octx;
// Unpack context
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * dst = &octx->dst;
// Scratchpad memory
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
// Dimensions
uint32_t ne00 = src0->ne[0];
uint32_t ne01 = src0->ne[1];
uint32_t ne02 = src0->ne[2];
uint32_t ne03 = src0->ne[3];
uint32_t nb01 = src0->nb[1];
//uint32_t nb02 = src0->nb[2];
//uint32_t nb03 = src0->nb[3];
uint32_t nb1 = dst->nb[1];
//uint32_t nb2 = dst->nb[2];
//uint32_t nb3 = dst->nb[3];
// Sort order
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
// Rows to process
uint32_t total_rows = ne01 * ne02 * ne03;
uint32_t rows_per_thread = actx->nrows_per_thread;
uint32_t start_row = rows_per_thread * i;
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
// Scratchpad layout:
// We need space for one row of float data (values) and one row of int32 indices.
// values: ne00 * sizeof(float)
// indices: ne00 * sizeof(int32_t)
// Padded to 128 bytes.
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
float * values_buf = (float *) spad;
int32_t * indices_buf = (int32_t *) (spad + values_size);
for (uint32_t r = start_row; r < end_row; r++) {
uint32_t src_offset = r * nb01;
uint32_t dst_offset = r * nb1;
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
// Initialize indices
for (uint32_t j = 0; j < ne00; j++) {
indices_buf[j] = j;
}
// Sort values and mirror swaps to indices
if (order == GGML_SORT_ORDER_ASC) {
quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
} else {
quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
}
// Copy indices back to DDR
hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
}
}
int op_argsort(struct htp_ops_context * octx) {
// Check supported types
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
// Allocate scratchpad
// We need 1 row of float + 1 row of int32 per thread.
uint32_t ne00 = octx->src0.ne[0];
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
size_t spad_per_thread = values_size + indices_size;
// Make sure we round up to 256 for alignment requirements
spad_per_thread = hex_round_up(spad_per_thread, 256);
size_t total_spad_size = spad_per_thread * octx->n_threads;
if (octx->ctx->vtcm_size < total_spad_size) {
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src0_spad.size = total_spad_size;
octx->src0_spad.size_per_thread = spad_per_thread;
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
octx->src0.data, octx->dst.data);
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
struct htp_argsort_context actx;
actx.octx = octx;
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
// Run jobs
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
return HTP_STATUS_OK;
}
File diff suppressed because it is too large Load Diff
+254 -272
View File
@@ -17,121 +17,6 @@
#include "htp-msg.h"
#include "htp-ops.h"
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
}
// MAD: y (F32) += x (F16) * s (float)
// MAD: y (F32) += x (F16) * s (F32)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
}
}
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
const void * restrict x0,
const void * restrict x1,
float s0,
float s1,
int n) {
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(s0);
HVX_Vector S1 = hvx_vec_splat_f16(s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
}
if (nloe) {
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
HVX_Vector xs = xs_p_lo;
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i;
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
}
}
}
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
struct htp_fa_context {
const struct htp_ops_context * octx;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
float scale;
float max_bias;
float logit_softcap;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t n_blocks;
size_t size_q_row_padded;
size_t size_k_row_padded;
size_t size_v_row_padded;
size_t size_k_block;
size_t size_v_block;
size_t size_m_block;
bool is_q_fp32;
};
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
}
}
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_fa_context * factx = (struct htp_fa_context *) data;
const struct htp_ops_context * octx = factx->octx;
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
// Clear accumulator
hvx_splat_f32_a(spad_a, 0, DV);
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
if (factx->is_q_fp32) {
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
}
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// Inner loop processing the block from VTCM
uint32_t ic = 0;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
// Process in blocks of 32 (VLEN_FP32)
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4;
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; j += 2) {
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else {
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
if (factx->logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
}
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);
}
scores_x4.v[iv] = scores;
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
}
{
// 4. Online Softmax Update
v_max = hvx_vec_reduce_max_f32(v_max);
float m_block = hvx_vec_get_f32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
M_vec = M_new_vec;
const float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = scores_x4.v[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
*(HVX_Vector *) p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
}
}
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}
// Sync scalars for leftover/next block if needed
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
if (factx->logit_softcap != 0.0f) {
s_val = factx->logit_softcap * tanhf(s_val);
}
if (mask) {
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
vs = expf(s_val - M);
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
M_vec = hvx_vec_splat_f32(M);
S_vec = hvx_vec_splat_f32(S);
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
if (ib + 2 < factx->n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
// sinks
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
S = S * ms + vs;
float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
struct htp_fa_context factx;
factx.octx = octx;
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
factx.scale = scale;
factx.max_bias = max_bias;
factx.logit_softcap = logit_softcap;
uint32_t n_head = q->ne[2];
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}
return HTP_STATUS_OK;
+26 -38
View File
@@ -42,32 +42,36 @@ enum htp_data_type {
HTP_TYPE_COUNT
};
// These values are manually translated over to HTP
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
// Do not reorder first 4 (used as an index)
enum htp_op {
HTP_OP_MUL = 0,
HTP_OP_ADD = 1,
HTP_OP_SUB = 2,
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT = 4,
HTP_OP_MUL_MAT_ID = 5,
HTP_OP_RMS_NORM = 6,
HTP_OP_UNARY_SILU = 7,
HTP_OP_UNARY_GELU = 8,
HTP_OP_GLU_SWIGLU = 9,
HTP_OP_GLU_SWIGLU_OAI = 10,
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
HTP_OP_FLASH_ATTN_EXT = 14,
HTP_OP_SET_ROWS = 15,
HTP_OP_SCALE = 16,
HTP_OP_GET_ROWS = 17,
HTP_OP_CPY = 18,
HTP_OP_MUL = 0,
HTP_OP_ADD = 1,
HTP_OP_SUB = 2,
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT,
HTP_OP_MUL_MAT_ID,
HTP_OP_RMS_NORM,
HTP_OP_UNARY_SILU,
HTP_OP_UNARY_GELU,
HTP_OP_GLU_SWIGLU,
HTP_OP_GLU_SWIGLU_OAI,
HTP_OP_GLU_GEGLU,
HTP_OP_SOFTMAX,
HTP_OP_ADD_ID,
HTP_OP_ROPE,
HTP_OP_FLASH_ATTN_EXT,
HTP_OP_SET_ROWS,
HTP_OP_GET_ROWS,
HTP_OP_SCALE,
HTP_OP_CPY,
HTP_OP_ARGSORT,
HTP_OP_SQR,
HTP_OP_SQRT,
HTP_OP_SUM_ROWS,
INVALID
};
static inline size_t htp_type_block_size(uint32_t t) {
static inline size_t htp_t_block_size(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return 1;
@@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return 0;
}
static const char * htp_type_name(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return "fp32";
case HTP_TYPE_F16:
return "fp16";
case HTP_TYPE_Q4_0:
return "q4_0";
case HTP_TYPE_Q8_0:
return "q8_0";
case HTP_TYPE_MXFP4:
return "mxfp4";
}
return 0;
}
// Internal types
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
+2 -13
View File
@@ -64,25 +64,12 @@ struct htp_ops_context {
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
uint32_t flags;
};
@@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
@@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */
+129 -116
View File
@@ -46,127 +46,76 @@
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// ADD variants
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
// Dispatcher logic
#define HVX_BINARY_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
if (hex_is_aligned((void *) dst, 128)) { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
else OP_NAME##_aau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
else OP_NAME##_auu(dst, src0, src1, num_elems); \
} \
} else { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
else OP_NAME##_uau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
} \
} \
}
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
}
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
}
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
}
// SUB variants
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
}
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
}
// MUL variants
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
}
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
}
// Dispatchers
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_add_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_add_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_uu(dst, src0, src1, num_elems);
}
}
HVX_BINARY_DISPATCHER(hvx_add_f32)
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
// Mul-Mul Optimized
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
@@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
}
}
//
// Square
//
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
} \
if (nloe) { \
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src, 128)) {
hvx_sqr_f32_aa(dst, src, num_elems);
} else {
hvx_sqr_f32_au(dst, src, num_elems);
}
} else {
if (hex_is_aligned((void *) src, 128)) {
hvx_sqr_f32_ua(dst, src, num_elems);
} else {
hvx_sqr_f32_uu(dst, src, num_elems);
}
}
}
#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER
#endif // HVX_ARITH_H
+6
View File
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
return x;
}
static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
int32_t __attribute__((aligned(128))) x;
hvx_vec_store_a(&x, 4, v);
return x;
}
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
// abs by clearing the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
-2
View File
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector zero = Q6_V_vsplat_R(0); \
\
const uint32_t elem_size = sizeof(__fp16); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
+116
View File
@@ -0,0 +1,116 @@
#ifndef HVX_DIV_H
#define HVX_DIV_H
#include <HAP_farf.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
#include "hex-utils.h"
#include "hvx-inverse.h"
#include "hvx-arith.h"
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
} \
} while(0)
// 3-letter suffix variants
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
else hvx_div_f32_aau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
else hvx_div_f32_auu(dst, src0, src1, num_elems);
}
} else {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
else hvx_div_f32_uau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
}
}
}
#undef HVX_OP_MUL
#endif // HVX_DIV_H
+27
View File
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
} \
} while(0)
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t epv = 128 / sizeof(float); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
} \
if (nloe) { \
HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
} \
} while(0)
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
#endif /* HVX_SIGMOID_H */
+67 -1
View File
@@ -12,11 +12,17 @@
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
//Algorithm :
// x2 = input*0.5
// y = * (long *) &input
// y = 0x5f3759df - (y>>2)
// y = 0x5f3759df - (y>>1)
// y = y*(threehalfs - x2*y*y)
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
@@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
return Q6_Vsf_equals_Vqf32(temp);
}
// Compute sqrt(x) as x*inv_sqrt(x)
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vdst[i] = sqrt_res; \
} \
if (nloe) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
} \
} while(0)
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
if ((unsigned long) dst % 128 == 0) {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_aa(dst, src, num_elems);
} else {
hvx_sqrt_f32_au(dst, src, num_elems);
}
} else {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_ua(dst, src, num_elems);
} else {
hvx_sqrt_f32_uu(dst, src, num_elems);
}
}
}
#endif /* HVX_SQRT_H */
+1
View File
@@ -12,6 +12,7 @@
#include "hvx-sigmoid.h"
#include "hvx-sqrt.h"
#include "hvx-arith.h"
#include "hvx-div.h"
#include "hvx-base.h"
#endif /* HVX_UTILS_H */
+108 -1
View File
@@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
// otherwise we'll release it once we're done with the current Op.
if (ctx->vtcm_inuse) {
ctx->vtcm_needs_release = false;
ctx->vtcm_needs_release = true;
return 0;
}
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_argsort(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
@@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_sum_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_activations_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
@@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
case HTP_OP_DIV:
if (n_bufs != 3) {
FARF(ERROR, "Bad binary-req buffer list");
continue;
@@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_unary_req(ctx, &req, bufs);
break;
case HTP_OP_SQR:
case HTP_OP_SQRT:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
}
proc_unary_req(ctx, &req, bufs);
break;
case HTP_OP_SUM_ROWS:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
}
proc_sum_rows_req(ctx, &req, bufs);
break;
case HTP_OP_UNARY_SILU:
case HTP_OP_UNARY_GELU:
if (n_bufs != 2) {
@@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_GLU_SWIGLU:
case HTP_OP_GLU_SWIGLU_OAI:
case HTP_OP_SOFTMAX:
case HTP_OP_GLU_GEGLU:
if ((n_bufs != 2) && (n_bufs != 3)) {
FARF(ERROR, "Bad act-req buffer list");
continue;
@@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_cpy_req(ctx, &req, bufs);
break;
case HTP_OP_ARGSORT:
if (n_bufs != 2) {
FARF(ERROR, "Bad argsort-req buffer list");
continue;
}
proc_argsort_req(ctx, &req, bufs);
break;
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;
File diff suppressed because it is too large Load Diff
+115
View File
@@ -0,0 +1,115 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <string.h>
#include <math.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#define sum_rows_preamble \
struct htp_tensor *src0 = &octx->src0;\
struct htp_tensor *dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne03 = src0->ne[3]; \
\
const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t ne3 = dst->ne[3]; \
\
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
sum_rows_preamble;
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return HTP_STATUS_OK;
}
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
const float * restrict src_local = src_th + (ir * ne00);
if (ir + 1 < src0_nrows_per_thread) {
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
}
if (1 == opt_path) {
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
} else {
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
}
}
return HTP_STATUS_OK;
}
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
}
int op_sum_rows(struct htp_ops_context * octx) {
sum_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
const int n_threads = octx->n_threads;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
return HTP_STATUS_OK;
}
+64
View File
@@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
}
}
static void sqr_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
}
static void sqrt_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
}
static void unary_job_f32_per_thread(const struct htp_tensor * src,
struct htp_tensor * dst,
uint8_t * spad,
@@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQR:
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQRT:
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
default:
break;
@@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
case HTP_OP_SQR:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqr-f32";
break;
case HTP_OP_SQRT:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqrt-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
+4
View File
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
}
}
static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
static inline float ggml_compute_softplus_f32(float input) {
return (input > 20.0f) ? input : logf(1 + expf(input));
}
+13 -2
View File
@@ -264,15 +264,26 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_SUM_ROWS:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_CLAMP:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:
case GGML_OP_UNARY:
case GGML_OP_GET_ROWS:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_REPEAT:
return true;
default:
return ggml_op_is_empty(op);
@@ -312,7 +323,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
h_add(mrs1, node0);
// that many nodes forward to search for a concurrent node
constexpr int N_FORWARD = 8;
constexpr int N_FORWARD = 64;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {
+76 -50
View File
@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
char base[256];
char name[256];
const int64_t n = ggml_nelements(op);
int op_num = -1;
const char * op_str = "undefined";
switch (op->op) {
case GGML_OP_SCALE: op_str = "scale"; break;
case GGML_OP_FILL: op_str = "fill"; break;
case GGML_OP_CLAMP: op_str = "clamp"; break;
case GGML_OP_SQR: op_str = "sqr"; break;
case GGML_OP_SQRT: op_str = "sqrt"; break;
case GGML_OP_SIN: op_str = "sin"; break;
case GGML_OP_COS: op_str = "cos"; break;
case GGML_OP_LOG: op_str = "log"; break;
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
case GGML_UNARY_OP_STEP: op_str = "step"; break;
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
};
const char * suffix = "";
if (n % 4 == 0) {
suffix = "_4";
}
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.c4 = is_c4;
res.cnt = is_cnt;
return res;
}
@@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
const char * op_str = "undefined";
int op_num = -1;
switch (op->op) {
case GGML_OP_SUM_ROWS:
op_str = "sum_rows"; break;
case GGML_OP_MEAN:
op_str = "mean"; break;
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d", base, op_num);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32*sizeof(float);
if (is_c4) {
res.smem *= 4;
}
res.c4 = is_c4;
return res;
}
@@ -1472,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1486,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
res.c4 = is_c4;
res.smem = 32*sizeof(float);
return res;
+13 -14
View File
@@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
}
switch (op->op) {
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CLAMP:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
@@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
}
@@ -1058,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
@@ -1087,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
@@ -1161,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
+73 -20
View File
@@ -80,7 +80,9 @@
#define FC_SSM_CONV 900
#define FC_SOLVE_TRI 1000
#define FC_COUNT_EQUAL 1100
#define FC_BIN 1200
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -89,6 +91,37 @@
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
#define OP_UNARY_NUM_SCALE 10
#define OP_UNARY_NUM_FILL 11
#define OP_UNARY_NUM_CLAMP 12
#define OP_UNARY_NUM_SQR 13
#define OP_UNARY_NUM_SQRT 14
#define OP_UNARY_NUM_SIN 15
#define OP_UNARY_NUM_COS 16
#define OP_UNARY_NUM_LOG 17
#define OP_UNARY_NUM_LEAKY_RELU 18
#define OP_UNARY_NUM_TANH 100
#define OP_UNARY_NUM_RELU 101
#define OP_UNARY_NUM_SIGMOID 102
#define OP_UNARY_NUM_GELU 103
#define OP_UNARY_NUM_GELU_ERF 104
#define OP_UNARY_NUM_GELU_QUICK 105
#define OP_UNARY_NUM_SILU 106
#define OP_UNARY_NUM_ELU 107
#define OP_UNARY_NUM_NEG 108
#define OP_UNARY_NUM_ABS 109
#define OP_UNARY_NUM_SGN 110
#define OP_UNARY_NUM_STEP 111
#define OP_UNARY_NUM_HARDSWISH 112
#define OP_UNARY_NUM_HARDSIGMOID 113
#define OP_UNARY_NUM_EXP 114
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -124,6 +157,31 @@ typedef struct {
int32_t dim;
} ggml_metal_kargs_concat;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float slope;
float scale;
float bias;
float val;
float min;
float max;
} ggml_metal_kargs_unary;
typedef struct {
int32_t ne00;
int32_t ne01;
@@ -181,20 +239,6 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_repeat;
typedef struct {
float scale;
float bias;
} ggml_metal_kargs_scale;
typedef struct {
float val;
} ggml_metal_kargs_fill;
typedef struct {
float min;
float max;
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
@@ -498,8 +542,21 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;
@@ -881,10 +938,6 @@ typedef struct {
int max_period;
} ggml_metal_kargs_timestep_embedding;
typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int32_t ne00;
int32_t ne01;
+264 -197
View File
@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_acc(ctx, idx);
} break;
case GGML_OP_SCALE:
{
n_fuse = ggml_metal_op_scale(ctx, idx);
} break;
case GGML_OP_FILL:
{
n_fuse = ggml_metal_op_fill(ctx, idx);
} break;
case GGML_OP_CLAMP:
{
n_fuse = ggml_metal_op_clamp(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
} break;
case GGML_OP_TRI:
{
n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
} break;
case GGML_OP_SET:
{
n_fuse = ggml_metal_op_set(ctx, idx);
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
@@ -628,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -679,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
}
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb00,
/*.nb01 =*/ pnb1,
/*.nb02 =*/ pnb2,
@@ -695,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
@@ -715,126 +707,19 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = 1;
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
return 1;
}
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float bias;
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_scale args = {
/*.scale =*/ scale,
/*.bias =*/ bias,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float val = ggml_get_op_params_f32(op, 0);
ggml_metal_kargs_fill args = {
/*.val =*/ val
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float min;
float max;
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_clamp args = {
/*.min =*/ min,
/*.max =*/ max,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -846,19 +731,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
int64_t n = ggml_nelements(op);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
if (n % 4 == 0) {
n /= 4;
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_unary args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.slope =*/ 0.0,
/*.scale =*/ 0.0,
/*.bias =*/ 0.0,
/*.val =*/ 0.0,
/*.min =*/ 0.0,
/*.max =*/ 0.0,
};
if (op->op == GGML_OP_LEAKY_RELU) {
args.slope = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_SCALE) {
args.scale = ggml_get_op_params_f32(op, 0);
args.bias = ggml_get_op_params_f32(op, 1);
}
if (op->op == GGML_OP_FILL) {
args.val = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_CLAMP) {
args.min = ggml_get_op_params_f32(op, 0);
args.max = ggml_get_op_params_f32(op, 1);
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne00, nth_max);
const int nk0 = (args.ne00 + nth - 1)/nth;
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}
@@ -969,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -990,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -1664,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
const size_t offs = ((const int32_t *) op->op_params)[3];
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
}
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
int64_t nk0 = ne10;
if (ggml_is_quantized(op->src[1]->type)) {
nk0 = ne10/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne10/ggml_blck_size(op->type);
}
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
}
}
}
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb10,
/*.nb01 =*/ nb11,
/*.nb02 =*/ nb12,
/*.nb03 =*/ nb13,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ ggml_element_size(op),
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
bid_dst.offs += offs;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
return 1;
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -3044,39 +3127,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
@@ -4084,42 +4187,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float slope;
memcpy(&slope, op->op_params, sizeof(float));
ggml_metal_kargs_leaky_relu args = {
/*.slope =*/ slope
};
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
+1 -4
View File
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
@@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
@@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
+275 -466
View File
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -895,6 +903,210 @@ enum ggml_sort_order {
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
}
#undef FC_OP
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
@@ -1114,414 +1326,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
kernel void kernel_scale_f32(
constant ggml_metal_kargs_scale & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_scale_f32_4(
constant ggml_metal_kargs_scale & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * args.scale + args.bias;
}
kernel void kernel_fill_f32(
constant ggml_metal_kargs_fill & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_fill_f32_4(
constant ggml_metal_kargs_fill & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
kernel void kernel_clamp_f32(
constant ggml_metal_kargs_clamp & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = clamp(src0[tpig], args.min, args.max);
}
kernel void kernel_clamp_f32_4(
constant ggml_metal_kargs_clamp & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = clamp(src0[tpig], args.min, args.max);
}
kernel void kernel_relu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
kernel void kernel_relu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
kernel void kernel_sigmoid_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_sigmoid_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
}
kernel void kernel_tanh_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = precise::tanh(src0[tpig]);
}
kernel void kernel_tanh_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = precise::tanh(src0[tpig]);
}
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
kernel void kernel_gelu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
// This was observed with Falcon 7B and 40B models
//
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_quick_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
kernel void kernel_gelu_erf_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}
kernel void kernel_silu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_silu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_elu_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
}
kernel void kernel_elu_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
}
kernel void kernel_sqr_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
kernel void kernel_sqr_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src0[tpig];
}
kernel void kernel_sqrt_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sqrt(src0[tpig]);
}
kernel void kernel_sqrt_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sqrt(src0[tpig]);
}
kernel void kernel_sin_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_sin_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_cos_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_cos_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_log_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = log(src0[tpig]);
}
kernel void kernel_log_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = log(src0[tpig]);
}
kernel void kernel_neg_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_neg_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = -src0[tpig];
}
kernel void kernel_abs_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_abs_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = fabs(src0[tpig]);
}
kernel void kernel_sgn_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sign(src0[tpig]);
}
kernel void kernel_sgn_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sign(src0[tpig]);
}
kernel void kernel_step_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = step(0.0f, src0[tpig]);
}
kernel void kernel_step_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = step(0.0f, src0[tpig]);
}
kernel void kernel_hardswish_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardswish_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_hardsigmoid_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
}
kernel void kernel_exp_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_exp_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]);
}
kernel void kernel_softplus_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
device const float & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_softplus_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
device const float4 & x = src0[tpig];
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
}
kernel void kernel_expm1_f32(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = exp(src0[tpig]) - 1.0f;
}
kernel void kernel_reglu_f32(
constant ggml_metal_kargs_glu & args,
device const char * src0,
@@ -1705,33 +1509,35 @@ kernel void kernel_op_sum_f32(
}
}
template <bool norm>
kernel void kernel_sum_rows(
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
#define FC_OP FC_sum_rows_op
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
shmem_t[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0;
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
@@ -1742,23 +1548,33 @@ kernel void kernel_sum_rows(
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
@@ -2639,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t ne02 = args.ne02;
const int32_t ne03 = args.ne03;
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
@@ -2928,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@@ -2965,12 +2784,16 @@ kernel void kernel_l2_norm_f32(
const float scale = 1.0f/sqrt(max(sumf, args.eps));
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -5072,24 +4895,6 @@ kernel void kernel_argsort_merge_f32_i32(
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
kernel void kernel_leaky_relu_f32(
constant ggml_metal_kargs_leaky_relu & args,
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
const float x = src0[tpig];
dst[tpig] = x > 0.0f ? x : x * args.slope;
}
kernel void kernel_leaky_relu_f32_4(
constant ggml_metal_kargs_leaky_relu & args,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
const float4 x = src0[tpig];
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -6161,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
const short T = PK + NSG*SH; // shared memory size per query in (half)
//const short T = PK + NSG*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
@@ -8749,7 +8554,9 @@ kernel void kernel_mul_mm(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -8872,8 +8679,8 @@ kernel void kernel_mul_mm(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
@@ -9122,7 +8929,9 @@ kernel void kernel_mul_mm_id(
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
@@ -9257,8 +9066,8 @@ kernel void kernel_mul_mm_id(
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short dx = sx;
const short dy = sy;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
@@ -9939,7 +9748,7 @@ kernel void kernel_opt_step_sgd_f32(
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_fill & args,
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
+6
View File
@@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q4_1_f32
mul_mv_q4_1_f32_flat
mul_mv_q4_k_f32
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
@@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS
gemv_moe_mxfp4_f32
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_general_q8_0_f32
mul
File diff suppressed because it is too large Load Diff
+51
View File
@@ -46,6 +46,15 @@ struct block_q4_0
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// block_q4_1
//------------------------------------------------------------------------------
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
@@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_1
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_1(
global struct block_q4_1 * src0,
global uchar * dst_q,
global half * dst_d,
global half * dst_m
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
global half * m = (global half *) dst_m + get_global_id(0);
*d = b->d;
*m = b->m;
for (int i = 0; i < QK4_1/2; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q4_1(
global uchar * src_q,
global half * src_d,
global half * src_m,
global struct block_q4_1 * dst
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
global half * m = (global half *) src_m + get_global_id(0);
b->d = *d;
b->m = *m;
for (int i = 0; i < QK4_1/2; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------
+87 -56
View File
@@ -3,80 +3,111 @@
//------------------------------------------------------------------------------
// expm1
//------------------------------------------------------------------------------
kernel void kernel_expm1_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
}
kernel void kernel_expm1_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
}
kernel void kernel_expm1_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
kernel void kernel_expm1_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_expm1_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = exp(*src_val_ptr) - 1;
}
*y = exp(*x) - 1.0f;
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_mean_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum / ne00;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
kernel void kernel_mean_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf / ne00;
}
}
@@ -0,0 +1,163 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,165 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global half * src0_m,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
float m = (float)src0_m[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,158 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 2
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 128; // 2 values per idx
int iqs = idx % 128; // 0..127
int n = iqs / 64; // 0,1
int b = (iqs % 64) / 32; // 0,1
int is_b = (iqs % 16) / 8; // 0,1
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
int is = 8 * n + qhshift + is_b; // 0..15
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,219 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y(
global const struct block_q4_1 * qb_curr,
float sumy,
float16 yl,
int il
) {
float d = qb_curr->d;
float m = qb_curr->m;
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32(
global void * src0,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
@@ -0,0 +1,229 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y_flat(
global const uchar * x,
global const half * dh,
global const half * mh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
float m = *mh;
global const ushort * qs = ((global const ushort *) x + il/2);
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// The number of scales/mins is the same as the number of blocks.
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_dm;
global half * m = (global half *) src0_m + offset0_dm;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}
@@ -0,0 +1,180 @@
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define K_SCALE_SIZE 12
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uchar qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32(
global char * src0,
int offset0,
global char * src1,
int offset1,
global char * dst,
int offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
ushort kmask1 = 0x3f3f;
ushort kmask2 = 0x0f0f;
ushort kmask3 = 0xc0c0;
int ix = get_sub_group_local_id()/8; // super block index
int it = get_sub_group_local_id()%8; // block index (inside super block)
int iq = it/4; // 0 or 1 - first or second half of the super block
int ir = it%4; // 0...3 - block index in the half super block
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
global float * y = (global float *) (src1 + offset_src1);
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f};
float all_sum;
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
ushort sc16[4];
uchar * sc8 = (uchar *)sc16;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+0];
sumy.s0 += yl[i+0];
yl[i+8] = y4[i+32];
sumy.s1 += yl[i+8];
yh[i+0] = y4[i+128];
sumy.s2 += yh[i+0];
yh[i+8] = y4[i+160];
sumy.s3 += yh[i+8];
}
global ushort * sc = (global ushort *)x[ib].scales + iq;
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
global half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
global ushort * q2 = q1 + 32;
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
q1 += nb01/2;
sc += nb01/2;
dh += nb01/2;
}
y4 += BLOCK_STRIDE * QK_K;
}
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = sub_group_reduce_add(sumf[row]);
if (first_row + row < ne01) {
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}
}
+88 -60
View File
@@ -3,86 +3,114 @@
//------------------------------------------------------------------------------
// softplus
//------------------------------------------------------------------------------
inline float softplus_f32(float x){
float ax = fabs(x);
float m = fmax(x, 0.0f);
return log1p(exp(-ax)) + m;
kernel void kernel_softplus_f32(
global const float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f32_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f32_4(
global const float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
const float x = convert_float(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
const float4 x = convert_float4(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = softplus_f32(*src_val_ptr);
}
*y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
}
}
kernel void kernel_softplus_f16_nd(
global void * p_src0_base,
ulong off_src0_abs,
global void * p_dst_base,
ulong off_dst_abs,
int ne00,
int ne01,
int ne02,
int ne03,
kernel void kernel_softplus_f16_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
int i0 = get_global_id(0);
int i1 = get_global_id(1);
int i2 = get_global_id(2);
src0 = src0 + offset0;
dst = dst + offsetd;
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
for (int i3 = 0; i3 < ne13; ++i3) {
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
}
const float x = convert_float(*hx);
*hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
}
+116 -15
View File
@@ -1,8 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
// Most devices have max workgroup size of 1024, so this is enough for subgroup
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
#define MAX_SUBGROUPS 64
kernel void kernel_sum_rows_f32(
global float * src0,
global char * src0,
ulong offset0,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
dst = dst + offsetd;
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
dst_row[0] = row_sum;
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float sumf = 0.0f;
for (int i0 = lid; i0 < ne00; i0 += lsize) {
sumf += src_row[i0];
}
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
kernel void kernel_sum_rows_f32_4(
global char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
__local float lmem[MAX_SUBGROUPS];
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
if(sg_id == 0){
lmem[sg_lid] = 0.0f;
}
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
float4 sum_vec = (float4)0.0f;
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
sum_vec += src_row[i0];
}
float sumf = dot(sum_vec, (float4)(1.0f));
sumf = sub_group_reduce_add(sumf);
barrier(CLK_LOCAL_MEM_FENCE);
if(sg_lid == 0){
lmem[sg_id] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
sumf = lmem[sg_lid];
sumf = sub_group_reduce_add(sumf);
if (lid == 0) {
dst_row[0] = sumf;
}
}
+92 -41
View File
@@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
#define VK_VENDOR_ID_QUALCOMM 0x5143
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
@@ -687,6 +688,7 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
vk_pipeline pipeline_add[2][2][2];
@@ -942,6 +944,7 @@ struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
@@ -961,6 +964,7 @@ struct vk_mat_vec_push_constants {
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
@@ -4080,7 +4084,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4181,7 +4185,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -5641,6 +5646,10 @@ static void ggml_vk_instance_init() {
driver_priorities[vk::DriverId::eMesaNvk] = 2;
#endif
break;
case VK_VENDOR_ID_QUALCOMM:
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
break;
}
driver_priorities[vk::DriverId::eMesaDozen] = 100;
@@ -6766,8 +6775,16 @@ static void ggml_vk_matmul(
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
base_work_group_z += groups_z;
}
return;
}
@@ -6781,9 +6798,17 @@ static void ggml_vk_matmul(
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
base_work_group_z += groups_z;
}
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7179,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (qx_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
@@ -7477,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7572,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags, base_work_group_y,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, groups_y, groups_z });
base_work_group_y += groups_y;
}
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
@@ -7825,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
src1->nb[2] <= src1->nb[1] &&
src1->nb[1] <= src1->nb[3] &&
src0->ne[3] == 1 &&
src1->ne[3] == 1) {
src1->ne[3] == 1 &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
@@ -8422,6 +8457,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
@@ -8438,7 +8475,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t slope = Br * acctype;
const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
@@ -8815,6 +8852,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_acc_f32;
}
return nullptr;
case GGML_OP_SET:
if (src0->type == src1->type && src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
return ctx->device->pipeline_set_f32;
}
return nullptr;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9801,16 +9844,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
0,
0.0f, 0.0f, offset,
});
@@ -10624,8 +10667,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
const float * op_params = (const float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11543,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12052,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
// y[i] = i % k;
}
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
@@ -12500,6 +12543,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ACC:
case GGML_OP_SET:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
break;
@@ -14896,8 +14940,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -14960,7 +15006,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_SET:
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
@@ -15611,6 +15660,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_SET) {
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {
+16 -8
View File
@@ -3,6 +3,9 @@
#include "types.glsl"
#include "generic_binary_head.glsl"
// false for SET, true for ACC
layout(constant_id = 1) const bool ACC = true;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
@@ -13,17 +16,22 @@ void main() {
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
const uint i3 = src1_i / p.nb03;
const uint rem2 = src1_i - i3 * p.nb03;
const uint i2 = rem2 / p.nb02;
const uint rem1 = rem2 - i2 * p.nb02;
const uint i1 = rem1 / p.nb01;
const uint i0 = rem1 % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
if (ACC) {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
}
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
}
}
@@ -130,6 +130,7 @@ void main() {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
@@ -137,12 +138,25 @@ void main() {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
@@ -260,6 +274,9 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
@@ -213,6 +215,19 @@ void main() {
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
@@ -176,7 +176,14 @@ void main() {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
@@ -184,7 +191,14 @@ void main() {
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
}
@@ -1,6 +1,6 @@
#version 450
#include "generic_head.glsl"
#include "generic_unary_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,22 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
sum[tid] += xi * xi;
}
@@ -35,7 +38,7 @@ void main() {
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
}
}
@@ -32,6 +32,7 @@ layout (push_constant) uniform parameter
uint expert_i1;
uint nbi1;
#else
uint base_work_group_y;
uint ne02;
uint ne12;
uint broadcast2;
@@ -45,9 +46,9 @@ uint expert_id;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_i0 = gl_GlobalInvocationID.y;
const uint expert_i0 = gl_WorkGroupID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
#endif
#ifndef MUL_MAT_ID
@@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -139,7 +141,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -149,7 +151,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -366,7 +368,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
#ifdef COOPMAT
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
@@ -197,7 +199,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
@@ -215,7 +217,7 @@ void main() {
#endif
#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
@@ -255,7 +257,7 @@ void main() {
#else
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif
uint stride_a = p.stride_a / QUANT_K;
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,5 +1,4 @@
#decl(BYTE_HELPERS)
#ifdef BYTE_HELPERS
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}
@@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#endif
#enddecl(BYTE_HELPERS)
#decl(Q4_0_T)
#ifdef Q4_0_T
struct q4_0 {
d: f16,
qs: array<f16, 8>
};
#enddecl(Q4_0_T)
#endif
#decl(Q4_1_T)
#ifdef Q4_1_T
struct q4_1 {
d: f16,
m: f16,
qs: array<u32, 4>
};
#enddecl(Q4_1_T)
#endif
#decl(Q5_0_T)
#ifdef Q5_0_T
struct q5_0 {
d: f16,
qh: array<f16, 2>,
qs: array<f16, 8>
};
#enddecl(Q5_0_T)
#endif
#decl(Q5_1_T)
#ifdef Q5_1_T
struct q5_1 {
d: f16,
m: f16,
qh: u32,
qs: array<u32, 4>
};
#enddecl(Q5_1_T)
#endif
#decl(Q8_0_T)
#ifdef Q8_0_T
struct q8_0 {
d: f16,
qs: array<f16, 16>
};
#enddecl(Q8_0_T)
#endif
#decl(Q8_1_T)
#ifdef Q8_1_T
struct q8_1 {
d: f16,
m: f16,
qs: array<u32, 8>
};
#enddecl(Q8_1_T)
#endif
#decl(Q2_K_T)
struct q2_k {
#ifdef Q2_K_T
struct q2_K {
scales: array<u32, 4>,
qs: array<u32, 16>,
d: f16,
dmin: f16
};
#enddecl(Q2_K_T)
#endif
#decl(Q3_K_T)
struct q3_k {
#ifdef Q3_K_T
struct q3_K {
hmask: array<f16, 16>,
qs: array<f16, 32>,
scales: array<f16, 6>,
d: f16
};
#enddecl(Q3_K_T)
#decl(Q45_K_SCALE_MIN)
#endif
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
if (is < 4) {
let sc_byte = get_byte(scales[is / 4], is % 4);
@@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
return vec2(f32(sc), f32(m));
}
}
#enddecl(Q45_K_SCALE_MIN)
#decl(Q4_K_T)
struct q4_k {
#endif
#ifdef Q4_K_T
struct q4_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qs: array<u32, 32>
};
#enddecl(Q4_K_T)
#endif
#decl(Q5_K_T)
struct q5_k {
#ifdef Q5_K_T
struct q5_K {
d: f16,
dmin: f16,
scales: array<u32, 3>,
qh: array<u32, 8>,
qs: array<u32, 32>
};
#enddecl(Q5_K_T)
#endif
#decl(Q6_K_T)
struct q6_k {
#ifdef Q6_K_T
struct q6_K {
ql: array<f16, 64>,
qh: array<f16, 32>,
scales: array<f16, 8>,
d: f16
};
#enddecl(Q6_K_T)
#endif
#decl(IQ2_XXS_T)
#ifdef IQ2_XXS_T
struct iq2_xxs {
d: f16,
qs: array<f16, 32>
};
#enddecl(IQ2_XXS_T)
#endif
#decl(IQ2_XS_T)
#ifdef IQ2_XS_T
struct iq2_xs {
d: f16,
qs: array<f16, 32>,
scales: array<f16, 4>
};
#enddecl(IQ2_XS_T)
#endif
#decl(IQ2_S_T)
#ifdef IQ2_S_T
struct iq2_s {
d: f16,
qs: array<f16, 32>,
qh: array<f16, 4>,
scales: array<f16, 4>
};
#enddecl(IQ2_S_T)
#endif
#decl(IQ3_XSS_T)
#ifdef IQ3_XXS_T
struct iq3_xxs {
d: f16,
qs: array<f16, 48>
};
#enddecl(IQ3_XSS_T)
#endif
#decl(IQ3_S_T)
#ifdef IQ3_S_T
struct iq3_s {
d: f16,
qs: array<f16, 32>,
@@ -161,41 +156,41 @@ struct iq3_s {
signs: array<f16, 16>,
scales: array<f16, 2>
};
#enddecl(IQ3_S_T)
#endif
#decl(IQ1_S_T)
#ifdef IQ1_S_T
struct iq1_s {
d: f16,
qs: array<f16, 16>,
qh: array<f16, 8>
};
#enddecl(IQ1_S_T)
#endif
#decl(IQ1_M_T)
#ifdef IQ1_M_T
struct iq1_m {
qs: array<u32, 8>,
qh: array<u32, 4>,
scales: array<u32, 2>
};
#enddecl(IQ1_M_T)
#endif
#decl(IQ4_NL_T)
#ifdef IQ4_NL_T
struct iq4_nl {
d: f16,
qs: array<f16, 8>,
};
#enddecl(IQ4_NL_T)
#endif
#decl(IQ4_XS_T)
#ifdef IQ4_XS_T
struct iq4_xs {
d: f16,
scales_h: f16,
scales_l: u32,
qs: array<u32, 32>
};
#enddecl(IQ4_XS_T)
#endif
#decl(IQ23_TABLES)
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
0x08040201u, // 1, 2, 4, 8
0x80402010u // 16, 32, 64, 128
@@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
);
#enddecl(IQ23_TABLES)
#endif
#decl(IQ2_XXS_GRID)
#ifdef IQ2_XXS_GRID
const iq2xxs_grid = array<u32, 512>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
);
#enddecl(IQ2_XXS_GRID)
#endif
#decl(IQ2_XS_GRID)
#ifdef IQ2_XS_GRID
const iq2xs_grid = array<u32, 1024>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_XS_GRID)
#endif
#decl(IQ2_S_GRID)
#ifdef IQ2_S_GRID
const iq2s_grid = array<u32, 2048>(
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
);
#enddecl(IQ2_S_GRID)
#decl(IQ3_XSS_GRID)
#endif
#ifdef IQ3_XXS_GRID
const iq3xxs_grid = array<u32, 256>(
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
);
#enddecl(IQ3_XSS_GRID)
#decl(IQ3_S_GRID)
#endif
#ifdef IQ3_S_GRID
const iq3s_grid = array<u32, 512>(
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
);
#enddecl(IQ3_S_GRID)
#endif
#decl(IQ1_GRID)
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
const IQ1_DELTA: f32 = 0.125;
@@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
);
#enddecl(IQ1_GRID)
#endif
#decl(IQ4_GRID)
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
const kvalues_iq4nl = array<i32, 16>(
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
);
#enddecl(IQ4_GRID)
#endif
@@ -56,12 +56,46 @@ def expand_includes(shader, input_dir):
return include_pattern.sub(replacer, shader)
def write_shader(shader_name, shader_code, output_dir, outfile):
def chunk_shader(shader_code, max_chunk_len=60000):
"""Split shader_code into safe raw-string sized chunks."""
return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]
def raw_delim(shader_code):
"""Pick a raw-string delimiter that does not appear in the shader."""
delim = "wgsl"
while f"){delim}\"" in shader_code:
delim += "_x"
return delim
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
shader_code = expand_includes(shader_code, input_dir)
if output_dir:
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
f_out.write(shader_code)
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
delim = raw_delim(shader_code)
chunks = chunk_shader(shader_code)
if len(chunks) == 1:
outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
else:
for idx, chunk in enumerate(chunks):
outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n')
outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n')
outfile.write(' static const std::string s = []{\n')
outfile.write(' std::string tmp;\n')
outfile.write(f' tmp.reserve({len(shader_code)});\n')
for idx in range(len(chunks)):
outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n')
outfile.write(' return tmp;\n')
outfile.write(' }();\n')
outfile.write(' return s;\n')
outfile.write('}\n')
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
def generate_variants(fname, input_dir, output_dir, outfile):
@@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
@@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():
@@ -137,7 +171,8 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n\n")
out.write("// Auto-generated shader embedding\n")
out.write("#include <string>\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(fname, args.input_dir, args.output_dir, out)
@@ -1,222 +1,31 @@
#define(VARIANTS)
enable f16;
#include "common_decls.tmpl"
[
{
"SHADER_SUFFIX": "f32_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f32>",
"BLOCK_SIZE": 4
},
"DECLS": ["F32_VEC"]
},
{
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F32"]
},
{
"REPLS": {
"TYPE" : "f16",
"DST_TYPE": "f32",
"BLOCK_SIZE": 1
},
"DECLS": ["F16"]
},
{
"REPLS": {
"TYPE" : "i32",
"DST_TYPE": "i32",
"BLOCK_SIZE": 1
},
"DECLS": ["I32"]
},
{
"REPLS": {
"TYPE" : "q4_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"TYPE" : "q4_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"TYPE" : "q5_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"TYPE" : "q5_1",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"TYPE" : "q8_0",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"TYPE" : "q2_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"TYPE" : "q3_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"TYPE" : "q4_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"TYPE" : "q5_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"TYPE" : "q6_k",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"TYPE" : "iq2_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"TYPE" : "iq2_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"TYPE": "iq2_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"TYPE": "iq3_xxs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"TYPE": "iq3_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"TYPE": "iq1_s",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"TYPE": "iq1_m",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"TYPE": "iq4_nl",
"DST_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"TYPE": "iq4_xs",
"DST_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(F32_VEC)
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
}
#enddecl(F32_VEC)
#endif
#decl(F32)
#ifdef F32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(F32)
#endif
#decl(F16)
#ifdef F16
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = f32(src[src_base + offset]);
}
#enddecl(F16)
#endif
#decl(I32)
#ifdef I32
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[dst_base + offset] = src[src_base + offset];
}
#enddecl(I32)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_0 = src[src_base + offset];
let d = f32(block_q4_0.d);
@@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q4_1 = src[src_base + offset];
let d = f32(block_q4_1.d);
@@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_0 = src[src_base + offset];
let d = f32(block_q5_0.d);
@@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(Q5_0)
#decl(Q5_1)
#ifdef Q5_1
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q5_1 = src[src_base + offset];
let d = f32(block_q5_1.d);
@@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_q8_0 = src[src_base + offset];
let d = f32(block_q8_0.d);
@@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q8_0)
#endif
#decl(Q2_K)
#ifdef Q2_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q2_K)
#endif
#decl(Q3_K)
#ifdef Q3_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q3_K)
#endif
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q4_K)
#endif
#decl(Q5_K)
#ifdef Q5_K
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(Q5_K)
#endif
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
sc_b_idx += 8;
}
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XXS)
#endif
#decl(IQ2_XS)
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ2_XS)
#endif
#decl(IQ2_S)
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_XSS)
#endif
#decl(IQ3_S)
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
@@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
}
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i++;
}
}
#enddecl(IQ4_NL)
#endif
#decl(IQ4_XS)
#ifdef IQ4_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block = src[src_base + offset];
let d = f32(block.d);
@@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst_i += 16;
}
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> idx: array<i32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
offset_src: u32, // in elements
@@ -842,8 +638,7 @@ struct Params {
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
@@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
copy_elements(i_src_row, i_dst_row, i);
}
}
#end(SHADER)
@@ -1,195 +1,24 @@
#define(VARIANTS)
enable f16;
[
{
"REPLS": {
"SRC0_TYPE" : "f32",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f16",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE" : "f16",
"SRC1_TYPE" : "f32",
"BLOCK_SIZE" : 1
},
"DECLS" : ["FLOAT"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_1",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
},
{
"REPLS": {
"SRC0_TYPE": "q8_0",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32
},
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
},
{
"REPLS": {
"SRC0_TYPE": "q2_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q3_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q4_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q5_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
},
{
"REPLS": {
"SRC0_TYPE": "q6_k",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq2_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_xxs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
},
{
"REPLS": {
"SRC0_TYPE": "iq3_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_s",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
},
{
"REPLS": {
"SRC0_TYPE": "iq1_m",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256
},
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_nl",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 32,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
},
{
"REPLS": {
"SRC0_TYPE": "iq4_xs",
"SRC1_TYPE": "f32",
"BLOCK_SIZE": 256,
},
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
}
]
#include "common_decls.tmpl"
#end(VARIANTS)
#ifdef FLOAT
const BLOCK_SIZE = 1u;
#define(DECLS)
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;
#decl(FLOAT)
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif
#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
}
#enddecl(FLOAT)
#endif
#decl(Q4_0)
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_0 = src0[src0_idx_base + offset];
let d = f32(block_q4_0.d);
@@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_0)
#endif
#decl(Q4_1)
#ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q4_1 = src0[src0_idx_base + offset];
let d = f32(block_q4_1.d);
@@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q4_1)
#endif
#decl(Q5_0)
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_0 = src0[src0_idx_base + offset];
let d = f32(block_q5_0.d);
@@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_0)
#endif
#decl(Q5_1)
#ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q5_1 = src0[src0_idx_base + offset];
let d = f32(block_q5_1.d);
@@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q5_1)
#endif
#decl(Q8_0)
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_0 = src0[src0_idx_base + offset];
let d = f32(block_q8_0.d);
@@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_0)
#endif
#decl(Q8_1)
#ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_q8_1 = src0[src0_idx_base + offset];
let d = f32(block_q8_1.d);
@@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(Q8_1)
#endif
#decl(Q2_K)
#ifdef Q2_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q2_K)
#decl(Q3_K)
#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q3_K)
#decl(Q4_K)
#ifdef Q4_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q4_K)
#decl(Q5_K)
#ifdef Q5_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q5_K)
#decl(Q6_K)
#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(Q6_K)
#decl(IQ2_XXS)
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XXS)
#decl(IQ2_XS)
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_XS)
#decl(IQ2_S)
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ2_S)
#decl(IQ3_XSS)
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ3_XSS)
#decl(IQ3_S)
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ3_S)
#endif
#decl(IQ1_S)
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_S)
#decl(IQ1_M)
#ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
@@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ1_M)
#decl(IQ4_NL)
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#endif
#enddecl(IQ4_NL)
#decl(IQ4_XS)
#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block = src0[src0_idx_base + offset];
let d = f32(block.d);
@@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
}
return sum;
}
#enddecl(IQ4_XS)
#end(DECLS)
#define(SHADER)
enable f16;
DECLS
#endif
struct MulMatParams {
offset_src0: u32, // in elements/blocks
@@ -864,8 +672,8 @@ struct MulMatParams {
broadcast3: u32
};
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
@group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
var sum = 0.0;
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i);
}
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}
#end(SHADER)

Some files were not shown because too many files have changed in this diff Show More