Compare commits

..

81 Commits

Author SHA1 Message Date
Johannes Gäßler 7049736b2d CUDA: fix numerical issues in tile FA kernel (#16540) 2025-10-13 17:29:45 +03:00
Jie Fu (傅杰) 01d2bdc2bc ggml : fix build broken with -march=armv9-a on MacOS (#16520)
* ggml : fix build broken with -march=armv9-a on MacOS

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Add #pragma message

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Address review comment.

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Update ggml/src/ggml-cpu/ggml-cpu.c

---------

Signed-off-by: Jie Fu <jiefu@tencent.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-10-13 15:48:47 +03:00
Chenguang Li 56fc38b965 CANN: fix CPU memory leak in CANN backend (#16549)
This commit fixes a CPU-side memory leak issue in the CANN backend,
which occurred when intermediate aclTensorList objects were not properly
released after operator execution. The leak happened during repeated
invocations of CANN ops (e.g., FlashAttention), leading to increasing
host memory usage over time.

Proper resource cleanup (aclDestroyTensorList and related release logic)
has been added to ensure that all temporary tensors are correctly freed.
2025-10-13 17:01:24 +08:00
Pascal 1fb9504eb7 fix: add remark plugin to render raw HTML as literal text (#16505)
* fix: add remark plugin to render raw HTML as literal text

Implemented a missing MDAST stage to neutralize raw HTML like major LLM WebUIs
do ensuring consistent and safe Markdown rendering

Introduced 'remarkLiteralHtml', a plugin that converts raw HTML nodes in the
Markdown AST into plain-text equivalents while preserving indentation and
line breaks. This ensures consistent rendering and prevents unintended HTML
execution, without altering valid Markdown structure

Kept 'remarkRehype' in the pipeline since it performs the required conversion
from MDAST to HAST for KaTeX, syntax highlighting, and HTML serialization

Refined the link-enhancement logic to skip unnecessary DOM rewrites,
fixing a subtle bug where extra paragraphs were injected after the first
line due to full innerHTML reconstruction, and ensuring links open in new
tabs only when required

Final pipeline: remarkGfm -> remarkMath -> remarkBreaks -> remarkLiteralHtml
-> remarkRehype -> rehypeKatex -> rehypeHighlight -> rehypeStringify

* fix: address review feedback from allozaur

* chore: update webui build output
2025-10-13 10:55:32 +02:00
Sam/Samuel 3f750f8d76 metal: add support for opt_step_sgd (#16539)
* metal: add support for opt_step_sgd

* add newline to pass EditorConfig check
2025-10-13 11:25:02 +03:00
Georgi Gerganov c515fc5771 ggml : fix scalar path for computing norm (#16558) 2025-10-13 11:22:27 +03:00
hipudding f9bc66c3eb CANN: Update several operators to support FP16 data format (#16251)
Many Ascend operators internally use FP16 precision for computation.
If input data is in FP32, it must first be cast to FP16 before
computation, and then cast back to FP32 after computation, which
introduces unnecessary cast operations. Moreover, FP16 computation
requires significantly less workload compared to FP32, leading to
noticeable efficiency improvements.

In this change, `get_rows`, `rms_norm`, and `flash_attn_ext` are extended
to support multiple data types. Validation on the Qwen2 0.5b model shows
correct accuracy and about 10% performance gain in concurrent scenarios.

Co-authored-by: noemotiovon <757486878@qq.com>
2025-10-13 08:52:22 +08:00
Sam/Samuel a31cf36ad9 metal : add opt_step_adamw and op_sum (#16529)
* scaffold to support opt step adamw on metal (not written so far)

* add opt-step-adamw kernel for metal

* pass op->src[4] as a separate buffer to the pipeline

* add bounds check to opt-step-adamw kernel

* complete scaffold for GGML_OP_SUM

* naive GGML_OP_SUM kernel

* remove unwanted comment

* change OP_SUM capability gate

* Add has_simdgroup_reduction to both ops to pass CI
2025-10-12 21:43:14 +03:00
Pascal 81d54bbfd5 webui: remove client-side context pre-check and rely on backend for limits (#16506)
* fix: make SSE client robust to premature [DONE] in agentic proxy chains

* webui: remove client-side context pre-check and rely on backend for limits

Removed the client-side context window pre-check and now simply sends messages
while keeping the dialog imports limited to core components, eliminating the
maximum context alert path

Simplified streaming and non-streaming chat error handling to surface a generic
'No response received from server' error whenever the backend returns no content

Removed the obsolete maxContextError plumbing from the chat store so state
management now focuses on the core message flow without special context-limit cases

* webui: cosmetic rename of error messages

* Update tools/server/webui/src/lib/stores/chat.svelte.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/stores/chat.svelte.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-10-12 18:06:41 +02:00
Neo Zhang Jianyu c7be9febcb [SYCL] fix UT fault cases: count-equal, argsort, pad OPs (#16521)
* fix/refactor OP argsort, pad

* fix count-equal op

* update SYCL OP list

* fix format issue

---------

Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
2025-10-12 21:53:35 +08:00
Mathieu Baudier 8415f61e23 ci : add Vulkan on Ubuntu with default packages build (#16532)
* ci: build Vulkan on Ubuntu with default packages

* ci: disable tests in Vulkan build with default Ubuntu packages
2025-10-12 15:48:03 +02:00
Aldehir Rojas 2c301e91ab common : handle unicode during partial json parsing (#16526)
* common : handle unicode during partial json parsing

* common : set missing `ensure_ascii = true` during json dump
2025-10-12 16:18:47 +03:00
Georgi Gerganov 4b2dae383d common : update presets (#16504)
* presets : add --embd-gemma-default and remove old embedding presets

* presets : add gpt-oss presets

* presets : add vision presets

* cont : remove reasoning overrides [no ci]

* cont : fix batch size for embedding gemma [no ci]
2025-10-12 09:29:13 +03:00
sirus20x6 41aac5c69b ggml : Fix FP16 ELU positive branch (#16519)
Co-authored-by: Aaron <shelhamer.aaron@gmail.com>
2025-10-12 08:25:37 +03:00
Daniel Bevenius a2fba89a42 hparams : add check for layer index in is_recurrent (#16511)
* hparams : add check for layer index in is_recurrent

This commit adds a check in the is_recurrent method to ensure that the
provided layer index is within the valid range.

The motivation for this change is to prevent potential out-of-bounds
and also be consistent with other methods in the class that perform
similar checks, like is_swa.
2025-10-12 07:19:06 +02:00
sirus20x6 20cc625edc ggml: Correct SVE implementation in ggml_vec_dot_f16_unroll (#16518)
The previous SVE implementation for `ggml_vec_dot_f16_unroll` contained a bug due to a copy-paste error. The wrong variable was used in an FMA instruction, leading to incorrect results. This commit corrects the variable usage and improves the clarity of the code by renaming variables to avoid confusion.

Co-authored-by: Aaron <shelhamer.aaron@gmail.com>
2025-10-12 08:15:00 +03:00
Johannes Gäßler 11f0af5504 CUDA: faster tile FA, add oob checks, more HSs (#16492) 2025-10-11 20:54:32 +02:00
Georgi Gerganov a3cb04744f metal : fix mul-mm condition + fix mul-mv permuted kernels (#16494) 2025-10-11 16:54:10 +03:00
Pascal 4a8fbe0a5e feat: render user content as markdown option (#16358)
* feat: render user content as markdown option
- Add a persisted 'renderUserContentAsMarkdown' preference to the settings defaults and info metadata so the choice survives reloads like other options
- Surface the new 'Render user content as Markdown' checkbox in the General section of the chat settings dialog, beneath the PDF toggle
- Render user chat messages with 'MarkdownContent' when the new setting is enabled, matching assistant formatting while preserving the existing card styling otherwise
- chore: update webui build output

* chore: update webui build output
2025-10-11 15:50:49 +02:00
Yann Follet 31d0ff1869 server / ranking : add sorting and management of top_n (#16403)
* server / ranking : add sorting and management of top_n

* Make the retro compatible if no top_n will return
all results

here is a script to make some test

```script

URL=${1:-http://127.0.0.1:8181}

curl "$URL/v1/rerank" -H "Content-Type: application/json" \
 -d '{ "model": "M", "query": "What is the recipe to make bread ?",
 "return_text" : true,
 "texts" : true,
 "top_n": 6,
 "documents": [
 "voici la recette pour faire du pain, il faut de la farine de l eau et du levain et du sel",
 "it is a bear",
 "bread recipe : floor, water, yest, salt",
 "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.",
 "here is the ingedients to bake bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt",
 "recipe to make cookies : floor, eggs, water, chocolat",
 "here is the recipe to make bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt",
 "il fait tres beau aujourd hui",
 "je n ai pas faim, je ne veux pas manger",
 "je suis a paris"
 ] }' | jq
```

* use resize() instead for(...)

* simplify top_n init since no need to return error

result to test :

./tests.sh unit/test_rerank.py -v -x
==================================================== test session starts =====================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3
cachedir: .pytest_cache
rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests
configfile: pytest.ini
plugins: anyio-4.11.0
collected 8 items

unit/test_rerank.py::test_rerank PASSED                                                                                [ 12%]
unit/test_rerank.py::test_rerank_tei_format PASSED                                                                     [ 25%]
unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED                                                        [ 37%]
unit/test_rerank.py::test_invalid_rerank_req[None] PASSED                                                              [ 50%]
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED                                                               [ 62%]
unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED                                                        [ 75%]
unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED                            [ 87%]
unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED               [100%]

===================================================== 8 passed in 4.31s ======================================================

* add rerank top_n unit test

here is the result :

./tests.sh unit/test_rerank.py -v -x
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3
cachedir: .pytest_cache
rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests
configfile: pytest.ini
plugins: anyio-4.11.0
collected 16 items

unit/test_rerank.py::test_rerank PASSED                                                                                                             [  6%]
unit/test_rerank.py::test_rerank_tei_format PASSED                                                                                                  [ 12%]
unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED                                                                                     [ 18%]
unit/test_rerank.py::test_invalid_rerank_req[None] PASSED                                                                                           [ 25%]
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED                                                                                            [ 31%]
unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED                                                                                     [ 37%]
unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED                                                         [ 43%]
unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED                                            [ 50%]
unit/test_rerank.py::test_rerank_top_n[None-4] PASSED                                                                                               [ 56%]
unit/test_rerank.py::test_rerank_top_n[2-2] PASSED                                                                                                  [ 62%]
unit/test_rerank.py::test_rerank_top_n[4-4] PASSED                                                                                                  [ 68%]
unit/test_rerank.py::test_rerank_top_n[99-4] PASSED                                                                                                 [ 75%]
unit/test_rerank.py::test_rerank_tei_top_n[None-4] PASSED                                                                                           [ 81%]
unit/test_rerank.py::test_rerank_tei_top_n[2-2] PASSED                                                                                              [ 87%]
unit/test_rerank.py::test_rerank_tei_top_n[4-4] PASSED                                                                                              [ 93%]
unit/test_rerank.py::test_rerank_tei_top_n[99-4] PASSED                                                                                             [100%]

=================================================================== 16 passed in 8.84s ===================================================================

* editor config check fix
2025-10-11 16:39:04 +03:00
Diego Devesa 97870e6497 cuda : avoid initializing unused devices (#16510) 2025-10-11 13:02:26 +02:00
amirai21 477a66b035 convert : correctly handle LLaMA tokenizer for Jamba (#16470)
* fix: convert_hf_to_gguf - change Jamba non-sentencepiece mode (tokenizer.json) vocab construction

* fix: convert_hf_to_gguf - jamba non-sentencepiece tokenizer to use _set_vocab_llama_hf func

* fix: convert_hf_to_gguf - removed get_vocab_base_pre from jamba
2025-10-11 10:33:41 +02:00
Georgi Gerganov e60f01d941 server : fix division by zero when reporting stats (#16501) 2025-10-10 22:15:05 +03:00
Georgi Gerganov 81086cd6a3 vocab : mark EOT token for Granite models (#16499)
* vocab : mark EOT token for Granite models

* sampling : fallback to EOS when EOT is not found
2025-10-10 17:17:31 +03:00
Radoslav Gerganov 68ee98ae18 server : return HTTP 400 if prompt exceeds context length (#16486)
In streaming mode when prompt exceeds context length, the server returns
HTTP 200 status code with a JSON error in the body.  This is very
confusing and inconsistent with all other inference engines which return
HTTP 4xx error in this case.

This patch fixes this problem and makes the server return HTTP 400 in
such cases.
2025-10-10 16:11:07 +02:00
Radoslav Gerganov cdb6da468c server : log requests to /v1/completions (#16495) 2025-10-10 13:22:27 +03:00
Prajwal B Mehendarkar 6d69ab3f26 cmake : Dont define XOPENSOURCE on AIX (#16481) 2025-10-10 11:15:46 +03:00
Pascal 1faa13a118 webui: updated the chat service to only include max_tokens in the req… (#16489)
* webui: updated the chat service to only include max_tokens in the request payload when the setting is explicitly provided, while still mapping explicit zero or null values to the infinite-token sentinel

* chore: update webui build output
2025-10-09 22:54:57 +02:00
duduta 1deee0f8d4 cpu : optimize the ggml NORM operation (#15953)
* ggml-cpu: optimize norm operation to use intrinsics or Accelerate

          rename function

          add endif macro comment

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Aaron Teo <taronaeo@gmail.com>

* implement s390x SIMD suggested by @taronaeo

* add TODO comment

* tidy up spaces

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Aaron Teo <taronaeo@gmail.com>
2025-10-09 21:11:15 +02:00
Georgi Gerganov d00cbea63c server : host-memory prompt caching (#16391)
* minor : code style

* server : fix prompt similarity calculation

* server : initial host-memory prompt caching

* cont

* server : refactor

* cont

* cont : make the server task of the slot const

* cont : minor [no ci]

* server : cache prompts and checkpoints only for completion tasks

* server : improve prompt caching logic

* cont : fix check for number of cached prompts [no ci]

* server : improve caching logic, add -cram CLI arg

* server : print prompt mismatch info

* cont : better naming [no ci]

* server : improve prompt cache loading logic

* server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : add option to disable prompt cache

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
2025-10-09 18:54:51 +03:00
Pascal 8328fd4bae No markdown in cot (#16483)
* fix: let the model think in plaintext

* chore: npm run format + npm run build
2025-10-09 17:36:29 +02:00
Daniel Bevenius 56b4795842 model-conversion : add support for SentenceTransformers (#16387)
* model-conversion : add support for SentenceTransformers

This commit adds support for models that use SentenceTransformer layers.

The motivation for this is that if converted model includes any of the
numbered layers specified in the original models repository then these
changes enable these models to be used and verified. Currently the
model-conversion only support the base model output without any of
the additional transformation layers.

Usage:
Convert the model that also includes the SentenceTransformer layers:
```console
(venv) $ export EMBEDDING_MODEL_PATH="~/google/embeddinggemma-300M"
(venv) make embedding-convert-model
```

Verify the produced embeddings from the converted model against the
original model embeddings:
```console
(venv) make embedding-verify-logits-st
```

The original model can be run using SentenceTransformer:
```console
(venv) make embedding-run-original-model-st
```

Run the converted model using "SentenceTransformer" layers whic
enables pooling and normalization:
```console
(venv) make embedding-run-converted-model-st
```

* add model-conversion example requirements

* add support for -st flag in embedding model conversion

This commit add support for the -st flag in the embedding model
conversion script. This will enable models to be converted using
sentence transformers dense layers.
2025-10-09 14:35:22 +02:00
sudhiarm 2c0d875ae6 ci: add ARM64 Kleidiai build and test support (#16462) 2025-10-09 10:13:18 +02:00
Chenguang Li aa4711d369 CANN: Improve ACL graph matching (#16166)
* CANN: improve ACL graph matching

Record `ne` and `nb` information for src tensors and include them in the
graph matching check. This enhances the robustness of ACL graph matching
by preventing incorrect matches when src tensors share the same data
address but differ in shape or stride.

* CANN: add op_params match
2025-10-09 15:50:25 +08:00
Charles Xu d80d6d2400 kleidiai: kernel interface refactoring (#16460) 2025-10-09 10:29:17 +03:00
Neo Zhang Jianyu b260213755 [SYCL] refactor soft_max, add soft_max_back (#16472)
* refactor to support soft_max_ext

* fix error and support soft_max_back

* rm unused functions

* fix format issue

---------

Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
2025-10-09 10:25:11 +03:00
Saba Fallah e08db42595 model: EmbeddingGemma Adding Support for SentenceTransformers Dense Modules (#16367)
* model: EmbeddingGemma sentence-transformers dense linear projections support

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

Adding support for the Dense modules used in EmbeddingGemma models.
EmbeddingGemma is a SentenceTransformers model with additional modules beyond the base Transformer backbone.

See: https://developers.googleblog.com/en/gemma-explained-embeddinggemma-architecture-and-recipe/

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

- converting model with dense-layers is optional
- introduced dense config params

* Update convert_hf_to_gguf.py

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>

* fixed formatting issues

* Update src/llama-graph.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* - removed pooling_type_opt, always allow overriding pooling_type
- asserts checking dense features dims

* fix python lint

* fix ubuntu gcc build warning

* - fixed thread-safety test
- moved asserts to load_hparams

* - tidying up code
- simplifying graph-context expecting both dense weights

* minor : add TODO

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-09 09:39:18 +03:00
Pascal 12bbc3fa50 refactor: centralize CoT parsing in backend for streaming mode (#16394)
* refactor: unify reasoning handling via backend reasoning_content, drop frontend tag parsing

- Updated the chat message component to surface backend-supplied reasoning via message.thinking while showing the raw assistant content without inline tag scrubbing
- Simplified chat streaming to append content chunks directly, stream reasoning into the message model, and persist any partial reasoning when generation stops
- Refactored the chat service SSE handler to rely on server-provided reasoning_content, removing legacy <think> parsing logic
- Refreshed Storybook data and streaming flows to populate the thinking field explicitly for static and streaming assistant messages

* refactor: implement streaming-aware universal reasoning parser

Remove the streaming mode limitation from --reasoning-format by refactoring
try_parse_reasoning() to handle incremental parsing of <think> tags across
all formats.

- Rework try_parse_reasoning() to track whitespace, partial tags, and
  multiple reasoning segments, allowing proper separation of reasoning_content
  and content in streaming mode
- Parse reasoning tags before tool call handling in content-only and Llama 3.x
  formats to ensure inline <think> blocks are captured correctly
- Change default reasoning_format from 'auto' to 'deepseek' for consistent
  behavior
- Add 'deepseek-legacy' option to preserve old inline behavior when needed
- Update CLI help and documentation to reflect streaming support
- Add parser tests for inline <think>...</think> segments

The parser now continues processing content after </think> closes instead of
stopping, enabling proper message.reasoning_content and message.content
separation in both streaming and non-streaming modes.

Fixes the issue where streaming responses would dump everything (including
post-thinking content) into reasoning_content while leaving content empty.

* refactor: address review feedback from allozaur

- Passed the assistant message content directly to ChatMessageAssistant to drop the redundant derived state in the chat message component
- Simplified chat streaming updates by removing unused partial-thinking handling and persisting partial responses straight from currentResponse
- Refreshed the ChatMessage stories to cover standard and reasoning scenarios without the old THINK-tag parsing examples

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* refactor: restore forced reasoning prefix to pass test-chat ([chat] All tests passed)

- store the exact sequence seen on input when 'thinking_forced_open' enforces a reasoning block
- inject this prefix before the first accumulated segment in 'reasoning_content', then clear it to avoid duplication
- repeat the capture on every new 'start_think' detection to properly handle partial/streaming flows

* refactor: address review feedback from ngxson

* debug: say goodbye to curl -N, hello one-click raw stream

- adds a new checkbox in the WebUI to display raw LLM output without backend parsing or frontend Markdown rendering

* Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: add Storybook example for raw LLM output and scope reasoning format toggle per story

- Added a Storybook example that showcases the chat message component in raw LLM output mode with the provided trace sample
- Updated every ChatMessage story to toggle the disableReasoningFormat setting so the raw-output rendering remains scoped to its own example

* npm run format

* chat-parser: address review feedback from ngxson

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
2025-10-08 23:18:41 +03:00
ai-fonsi 9d0882840e Disable CUDA host buffers on integrated GPUs (#16308) 2025-10-08 20:21:46 +02:00
issixx d2ee056e1d server : fix cancel pending task (#16467)
Co-authored-by: DevAI <DevAI@gmail.com>
2025-10-08 11:20:18 +03:00
Georgi Gerganov b2c08c9ec4 metal : mark FA blocks (#16372)
* metal : better unroll in the FA kernels

* metal : index FA blocks

* tests : restore [no ci]

* metal : prevent division by zero in FA kernels

* metal : fix -INF detection logic
2025-10-08 10:57:53 +03:00
Georgi Gerganov 7fdd16b432 server : improve context checkpoint logic (#16440) 2025-10-08 10:57:29 +03:00
Reese Levine 74b8fc17f9 ggml webgpu: profiling, CI updates, reworking of command submission (#16452)
* Add profiling

* More detailed profiling

* Rework command submission to avoid global locks

* Update wait handling

* try new method of waiting on futures

* Add serializing of command submission in some cases

* Add new pool for timestamp queries and clean up logging

* Serialize command submission in CI and leave a TODO note

* Update webgpu CI

* Add myself as WebGPU codeowner

* Deadlock avoidance

* Leave WebGPU/Vulkan CI serialized

* Fix divide by 0

* Fix logic in division by inflight_threads

* Update CODEOWNERS and remove serialize submit option
2025-10-07 13:48:56 -07:00
Tarek Dakhran aeaf8a36f0 llama : support LiquidAI LFM2-MoE hybrid model (#16464)
* llama : support LiquidAI LFM2-MoE hybrid model

Add support for [LiquidAI/LFM2-8B-A1B](https://huggingface.co/LiquidAI/LFM2-8B-A1B) model.
For more information about models, please read [the blog post](https://www.liquid.ai/company/news).

[HF PR](https://github.com/huggingface/transformers/pull/41401)
[GGUFs](https://huggingface.co/LiquidAI/LFM2-8B-A1B-GGUF)

* Do not use defaultdict

* Address PR feedback
2025-10-07 20:03:35 +02:00
Georgi Gerganov df1b612e29 server : add /v1/health endpoint (#16461)
* server : add /v1/health endpoint

* cont : update readme
2025-10-07 15:57:14 +03:00
Sascha Rogmann 4e0388aa8a webui : added download action (#13552) (#16282)
* webui : added download action (#13552)

* webui : import and export (for all conversations)

* webui : fixed download-format, import of one conversation

* webui : add ExportedConversations type for chat import/export

* feat: Update naming & order

* chore: Linting

* webui : Updated static build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-10-07 11:11:08 +02:00
Georgi Gerganov ef4c5b87ea presets : fix pooling param for embedding models (#16455) 2025-10-07 10:32:32 +03:00
Radoslav Gerganov c61ae20d05 rpc : update documentation (#16441)
Update the README file to match the newly added functionality of
exposing multiple devices from a single server.

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-10-07 06:59:13 +00:00
Georgi Gerganov 0123ff38f5 memory : use sequential equal splits for recurrent modules (#16442) 2025-10-07 08:24:17 +03:00
Georgi Gerganov 0a319bb75e metal : add support for non-padded FA KV (#16148)
* metal : pad K, V and Mask when needed

* cont : simplify

* cuda : add TODO about KV padding requirement

* metal : add comments

* metal : remove mask padding requirement
2025-10-07 08:23:30 +03:00
Georgi Gerganov 1d6092fc72 tests : add -INF blocks to the KQ mask in the FA tests (#16380)
* tests : add -INF blocks to the KQ mask in the FA tests

* cont : bump -INF block size to 64

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>

* ggml : prevent division by zero in FA CPU op

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
2025-10-07 08:22:35 +03:00
Georgi Gerganov 8ae32dc9ec metal : various optimizations + refactoring (#16446)
* metal : ssm_scan minor opts

* metal : get_rows optimize

* metal : cpy optimize

* metal : ssm_conv opt

* metal : ssm_scan simplify

* metal : ssm_Scan opt
2025-10-07 08:21:40 +03:00
Gadflyii 3df2244df4 llama : add --no-host to disable host buffers (#16310)
* implement --no-host to disable host buffer

* fix equal_mparams

* move no-host enumeration order together with other model params

---------

Co-authored-by: slaren <slarengh@gmail.com>
2025-10-06 19:55:53 +02:00
Gabe Goodhart c08002a198 chat : Granite Docling stopping (#16438)
* fix: Fix duplicate fake image before token on first slice

Branch: GraniteDoclingStopping

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use double-newline before overview image

Branch: GraniteDoclingStopping

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove incorrect newline at the end of granite chat template gen prompt

There should not be one, even for the language models.

Branch: GraniteDoclingStopping

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* tests: Remove bad newline from granite chat template test (legacy)

Branch: GraniteDoclingStopping

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-10-06 18:59:40 +02:00
Sigbjørn Skjæret 3a002afafa ci : refactor sdk caching to minimize storage (#16414)
* refactor sdk caching to minimize storage

* use correct action

* add myself as owner to /.github/actions/ [no ci]
2025-10-06 17:40:21 +02:00
Georgi Gerganov a23b9bdbd3 ggml : fix unaligned access in AMX code (#16315) 2025-10-06 16:05:27 +03:00
Daniel Bevenius 04e632a4aa ci : remove missing reranker model files (#16444)
This commit removes jina-reranker-v1-tiny-en model files that are no
longer present on Hugging Face.

The motivation for this that it clears up the CI logs from 404 errors
which can be a little confusing when looking at the logs the first time.

Refs: https://github.com/ggml-org/llama.cpp/actions/runs/18070620247/job/51419855630#step:5:2649
2025-10-06 14:56:59 +02:00
Daniel Bevenius a80ff183ab ggml-cpu : fix leftover handling in ggml_vec_scale_f32 for SVE (#16443)
This commit updates the leftover handling in ggml_vec_scale_f32.

The motivation for this is that the code currently incorrectly assumes
there would be fewer than ggml_f32_epr leftover elements. However,
since the main loop processes 2*ggml_f32_epr elements per iteration
, there can be up to (2*ggml_f32_epr - 1) leftover elements.

The original single-pass leftover code could only process ggml_f32_epr
elements, leaving some elements unscaled.

Example scenario with 256-bit SVE:
```
ggml_f32_epr  = 8 (elements per register)
ggml_f32_step = 16 (two registers per iteration)
n             = 25
np            = 16
leftovers     = 9 elements (16-24)

Original    : processes only elements 16-23, misses element 24
This commit : loop processes elements 16-23, then element 24
```

Refs: https://github.com/ggml-org/llama.cpp/actions/runs/18070620247/job/51419855630
2025-10-06 14:17:12 +02:00
Yuannan 1d49ca3759 nix : removed metal for nix (#16118) 2025-10-06 12:29:56 +03:00
Oleksandr Kuvshynov c5fef0fcea server: update readme to mention n_past_max metric (#16436)
https://github.com/ggml-org/llama.cpp/pull/15361 added new metric
exported, but I've missed this doc.
2025-10-06 10:53:31 +03:00
Gabe Goodhart ca71fb9b36 model : Granite docling + Idefics3 preprocessing (SmolVLM) (#16206)
* feat: Add granite-docling conversion using trillion pretokenizer

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add granite-docling vocab pre enum

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use granite-docling pre

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add clip_is_idefics3

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Allow multi-token boundary sequences for image templating

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add tiling support for idefices3 in clip.cpp

This should likely be moved into llava_uhd::get_slice_instructions, but for
now this avoids disrupting the logic there.

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Partial support for full templating for idefics3 in mtmd

There are still errors encoding some of the image chunks, but the token
sequence now matches transformers _almost_ perfectly, except for the double
newline before the global image which shows up as two consecutive newline
tokens instead of a single double-newline token. I think this is happening
because the blocks are tokenized separately then concatenated.

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Fully working image preprocessing for idefics3 w/ resize and slicing

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Parse the preprocessor config's longest side and add it to the mmproj hparams

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use the longest side instead of size * scale_factor

For Granite Docling, these come out to the same value, but that was just a
conicidence.

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Allow batch encoding and remove clip_is_idefics3

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Remove unnecessary conditionals for empty token vectors

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Use image_manipulation util

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* add test model

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-10-05 14:57:47 +02:00
Reese Levine 35266573b9 ggml webgpu: actually add softmax, fix rms_norm offset (#16400)
* implement soft_max

* Fix soft_max data race

* Temporary fix, wait on each submit
2025-10-04 20:59:31 -07:00
Eve 86df2c9ae4 vulkan: use a more appropriate amount of threads when generating shaders (#16418)
* use a more flexible amount of threads

* fix windows compile and 0 thread case

* nominmax
2025-10-04 22:04:27 +02:00
Radoslav Gerganov f39283960b rpc : check src buffer when copying tensor (#16421)
Only dst buffer is guaranteed to be an RPC buffer. Add check for the src
one.
2025-10-04 16:22:45 +03:00
Radoslav Gerganov 898acba681 rpc : add support for multiple devices (#16276)
* rpc : add support for multiple devices

Allow rpc-server to expose multiple devices from a single endpoint.
Change RPC protocol to include device identifier where needed.

closes: #15210

* fixes

* use ggml_backend_reg_t

* address review comments

* fix llama-bench backend report

* address review comments, change device naming

* fix cmd order
2025-10-04 12:49:16 +03:00
Acly e29acf74fe vulkan : incremental shader builds (#16341)
* vulkan (DRAFT): split shader generation by GLSL source file, to improve incremental build times

* support dep-files so shaders are recompiled if their included files change

* rename shader files which are used as "headers" to use .glsl extension
* move glslc extension detection shaders to separate folders
* the above is to prevent them from getting glob'd with the actual compute shaders that need to be compiled

* vulkan : only write embedded shader .hpp/.cpp when they change

* avoid recompiling ggml-vulkan.cpp when editing shaders
* pass single --source argument instead of --input-dir & --filter to shader gen
* check for source file match earlier

* fix hang in vulkan-shaders-gen when there are compilation errors

* early out did not decrement compile_count

* clean up

* fix glslc integer dot product test

* unconditionally write the embedded shader cpp output

* replace output filepath in generated dep-files to match output in CMakeLists

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
2025-10-04 11:42:56 +02:00
Pascal 128d522c04 chat : support Magistral thinking (#16413)
* feat: added a dedicated Magistral chat format that preserves [THINK] spans, parses reasoning before tool calls

* feat: new flow in the chat template test suite for Magistral
2025-10-03 21:51:48 +03:00
ddh0 f6dcda3900 server : context checkpointing for hybrid and recurrent models (#16382)
* initial commit for branch 3

* generalize `swa_checkpoint` to `ctx_checkpoint`

this extends `llama-server`'s SWA checkpointing logic to include
hybrid/recurrent models such as Jamba, Granite

* oops

* disable debug prints

* keep backwards compat with `--swa-checkpoints`

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update prompt re-processing message

* fix off-by-one error per GG

* keep `seq_rm` log per GG

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server : fix checkpoint logic to support recurrent caches

* server : cleanup and fixes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-03 21:34:51 +03:00
Georgi Gerganov 606a73f531 metal : fix loop bound in ggml_mem_ranges (#16412) 2025-10-03 19:18:56 +03:00
Sigbjørn Skjæret 946f71ed9a llama : fix shapes for bert/mpt q/k norm (#16409) 2025-10-03 14:40:25 +02:00
Acly 638d330246 ggml : fix graph reallocation with multiple chunks (#16396)
reallocation is needed if a single chunk grows in size,
even if total allocation size stays the same or is lower
2025-10-03 13:49:08 +02:00
Aleksander Grygier 84c8e305e8 Fix missing messages on sibling navigation (#16408)
* fix: resolve message disappearing issue when navigating between regenerated siblings by using current leaf nodes instead of cached sibling IDs

* chore: update webui build output

* chore: update webui build output
2025-10-03 12:51:40 +02:00
Jeff Bolz 2aaf0a2a20 vulkan: Replace uses of maxMemoryAllocationSize and VK_WHOLE_SIZE (#16354)
* vulkan: Replace uses of maxMemoryAllocationSize and VK_WHOLE_SIZE

Replace maxMemoryAllocationSize check with maxBufferSize when creating buffers.
The maxMemoryAllocationSize limit is a "soft" limit and allocations can succeed
beyond that limit. This allows > 4GB buffers to be allocated on some
implementations (e.g. NVIDIA) and tensors this large can be used for im2col
and mul_mat.

For temporary buffers (prealloc_x/y/etc) check against maxStorageBufferRange.
I'm not sure this check is ideal, but we always use these buffers as a single
full size binding and the limit may be smaller than maxMemoryAllocationSize
or maxBufferSize, so I think this is reasonable.

Replace descriptor range uses of VK_WHOLE_SIZE with a manually computed range.
The maxStorageBufferRange may be smaller than the maxBufferSize or
maxMemoryAllocationSize (and the Vulkan spec warns about this in a note) and
it's invalid usage if VK_WHOLE_SIZE computes a range larger than
maxStorageBufferRange.

With this change, it should be possible to generate videos using wan networks
in stable-diffusion.cpp.

* vulkan: Add env var GGML_VK_FORCE_MAX_BUFFER_SIZE and use stoull
2025-10-03 12:50:46 +02:00
Jeff Bolz 0e1f838556 vulkan: Fix FA coopmat1 invalid array indexing (#16365)
When computing sinks, the cm1 shader was looping r from 0 to Br rather than
to rows_per_thread. I must have copied this from the scalar path (where it is
correct), and somehow it wasn't causing failures on current drivers.
2025-10-03 11:52:46 +02:00
Daniel Bevenius ad126479c2 ci : change macos-13 to macos-15-intel (#16401)
This commit updates the macos-13 runners to macos-15-intel.

The motivation for this changes is the macos-13 runners are scheduled
to be retired on 2025-12-04.

Refs: https://github.blog/changelog/2025-09-19-github-actions-macos-13-runner-image-is-closing-down/
2025-10-03 11:45:16 +02:00
Aleksander Grygier 77233277c9 Capture model name only after first token (streaming) or completed request (#16405)
* feat: Capture model name only after first token (streaming) or completed request (non-streaming)

* chore: update webui build output

* chore: update webui build output
2025-10-03 11:30:39 +02:00
Jeff Bolz e308efda8e vulkan: in flash attention, bounds check against nem1 (don't rely on GGML_KQ_MASK_PAD) (#16316) 2025-10-03 10:33:08 +02:00
Aleksander Grygier 136bda78c5 webui : Fix messages payload sent to chat completions (#16402)
* fix: Include just the currently active message branches instead of all in chat completions request

* chore: Build webui static output

* chore: Formatting

* chore: update webui build output
2025-10-03 10:11:34 +03:00
Pascal 5113efd34c fix: track viewportHeight via window.innerHeight to avoid unwanted scrolling (#16356)
Use <svelte:window bind:innerHeight> instead of manual resize listener

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-10-03 08:01:31 +02:00
Sigbjørn Skjæret d64c8104f0 test-barrier : do not use more threads than physically available (#16389)
* do not use more threads than physically available

* ensure n_threads > 0

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
2025-10-02 20:10:12 +02:00
Reese Levine ef07a40906 ggml webgpu: add support for soft_max, optimize rms_norm (#16357)
* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-02 11:00:31 -07:00
306 changed files with 21866 additions and 9218 deletions
-4
View File
@@ -128,10 +128,6 @@ effectiveStdenv.mkDerivation (finalAttrs: {
};
postPatch = ''
substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \
--replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \
--replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";"
'';
# With PR#6015 https://github.com/ggml-org/llama.cpp/pull/6015,
+36
View File
@@ -0,0 +1,36 @@
name: "Install exe"
description: "Download and install exe"
inputs:
url:
description: "URL of the exe installer"
required: true
args:
description: "Installer arguments"
required: true
timeout:
description: "Timeout (in ms)"
required: false
default: "600000"
runs:
using: "composite"
steps:
- name: Install EXE
shell: pwsh
run: |
$ErrorActionPreference = "Stop"
write-host "Downloading Installer EXE"
Invoke-WebRequest -Uri "${{ inputs.url }}" -OutFile "${env:RUNNER_TEMP}\temp-install.exe"
write-host "Installing"
$proc = Start-Process "${env:RUNNER_TEMP}\temp-install.exe" -ArgumentList '${{ inputs.args }}' -NoNewWindow -PassThru
$completed = $proc.WaitForExit(${{ inputs.timeout }})
if (-not $completed) {
Write-Error "Installer timed out. Killing the process"
$proc.Kill()
exit 1
}
if ($proc.ExitCode -ne 0) {
Write-Error "Installer failed with exit code $($proc.ExitCode)"
exit 1
}
write-host "Completed installation"
@@ -0,0 +1,20 @@
name: "Linux - Setup SpacemiT Toolchain"
description: "Setup SpacemiT Toolchain for Linux"
inputs:
path:
description: "Installation path"
required: true
version:
description: "SpacemiT toolchain version"
required: true
runs:
using: "composite"
steps:
- name: Setup SpacemiT Toolchain
id: setup
uses: ./.github/actions/unarchive-tar
with:
url: https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ inputs.version }}.tar.xz
path: ${{ inputs.path }}
strip: 1
@@ -0,0 +1,20 @@
name: "Linux - Setup Vulkan SDK"
description: "Setup Vulkan SDK for Linux"
inputs:
path:
description: "Installation path"
required: true
version:
description: "Vulkan SDK version"
required: true
runs:
using: "composite"
steps:
- name: Setup Vulkan SDK
id: setup
uses: ./.github/actions/unarchive-tar
with:
url: https://sdk.lunarg.com/sdk/download/${{ inputs.version }}/linux/vulkan_sdk.tar.xz
path: ${{ inputs.path }}
strip: 1
+27
View File
@@ -0,0 +1,27 @@
name: "Unarchive tar"
description: "Download and unarchive tar into directory"
inputs:
url:
description: "URL of the tar archive"
required: true
path:
description: "Directory to unarchive into"
required: true
type:
description: "Compression type (tar option)"
required: false
default: "J"
strip:
description: "Strip components"
required: false
default: "0"
runs:
using: "composite"
steps:
- name: Unarchive into directory
shell: bash
run: |
mkdir -p ${{ inputs.path }}
cd ${{ inputs.path }}
curl --no-progress-meter ${{ inputs.url }} | tar -${{ inputs.type }}x --strip-components=${{ inputs.strip }}
@@ -0,0 +1,15 @@
name: "Windows - Setup ROCm"
description: "Setup ROCm for Windows"
inputs:
version:
description: "ROCm version"
required: true
runs:
using: "composite"
steps:
- name: Setup ROCm
uses: ./.github/actions/install-exe
with:
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe
args: -install
+89
View File
@@ -0,0 +1,89 @@
name: Build Actions Cache
on:
workflow_dispatch: # allows manual triggering
schedule:
- cron: '0 * * * *'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
ubuntu-24-vulkan-cache:
runs-on: ubuntu-24.04
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Get latest Vulkan SDK version
id: vulkan_sdk_version
run: |
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Setup Cache
uses: actions/cache@v4
id: cache-sdk
with:
path: ./vulkan_sdk
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
- name: Setup Vulkan SDK
if: steps.cache-sdk.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-vulkan
with:
path: ./vulkan_sdk
version: ${{ env.VULKAN_SDK_VERSION }}
ubuntu-24-spacemit-cache:
runs-on: ubuntu-24.04
env:
# Make sure this is in sync with build-linux-cross.yml
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Setup Cache
uses: actions/cache@v4
id: cache-toolchain
with:
path: ./spacemit_toolchain
key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
- name: Setup SpacemiT Toolchain
if: steps.cache-toolchain.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-spacemit
with:
path: ./spacemit_toolchain
version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
windows-2022-rocm-cache:
runs-on: windows-2022
env:
# Make sure this is in sync with build.yml
HIPSDK_INSTALLER_VERSION: "25.Q3"
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Setup Cache
uses: actions/cache@v4
id: cache-rocm
with:
path: C:\Program Files\AMD\ROCm
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
- name: Setup ROCm
if: steps.cache-rocm.outputs.cache-hit != 'true'
uses: ./.github/actions/windows-setup-rocm
with:
version: ${{ env.HIPSDK_INSTALLER_VERSION }}
+12 -14
View File
@@ -258,31 +258,29 @@ jobs:
runs-on: ubuntu-24.04
env:
# Make sure this is in sync with build-cache.yml
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64"
steps:
- uses: actions/checkout@v4
- name: Cache Toolchain
- name: Use SpacemiT Toolchain Cache
uses: actions/cache@v4
id: cache-spacemit-ime-cross-toolchain
id: cache-toolchain
with:
path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
path: ./spacemit_toolchain
key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
- name: Setup Toolchain
if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true'
run: |
wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
- name: Setup SpacemiT Toolchain
if: steps.cache-toolchain.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-spacemit
with:
path: ./spacemit_toolchain
version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
- name: Build
run: |
export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
export RISCV_ROOT_PATH=${PWD}/spacemit_toolchain
cmake -B build -DLLAMA_CURL=OFF \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENMP=OFF \
+101 -39
View File
@@ -97,7 +97,7 @@ jobs:
ctest -L 'main|curl' --verbose --timeout 900
macOS-latest-cmake-x64:
runs-on: macos-13
runs-on: macos-15-intel
steps:
- name: Clone
@@ -387,6 +387,39 @@ jobs:
cd build
ctest -L main --verbose
ubuntu-24-cmake-vulkan-deb:
runs-on: ubuntu-24.04
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-24-cmake-vulkan-deb
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
- name: Configure
id: cmake_configure
run: |
cmake -B build \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
-DGGML_BACKEND_DL=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
-DGGML_VULKAN=ON
- name: Build
id: cmake_build
run: |
cmake --build build -j $(nproc)
ubuntu-24-cmake-vulkan:
runs-on: ubuntu-24.04
@@ -413,20 +446,19 @@ jobs:
run: |
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Cache Vulkan SDK
id: cache_vulkan_sdk
- name: Use Vulkan SDK Cache
uses: actions/cache@v4
id: cache-sdk
with:
path: ./vulkan_sdk
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
- name: Install Vulkan SDK
if: steps.cache_vulkan_sdk.outputs.cache-hit != 'true'
id: vulkan_sdk_install
run: |
mkdir -p vulkan_sdk
cd vulkan_sdk
curl --no-progress-meter https://sdk.lunarg.com/sdk/download/latest/linux/vulkan_sdk.tar.xz | tar -Jx --strip-components=1
- name: Setup Vulkan SDK
if: steps.cache-sdk.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-vulkan
with:
path: ./vulkan_sdk
version: ${{ env.VULKAN_SDK_VERSION }}
- name: Build
id: cmake_build
@@ -445,8 +477,8 @@ jobs:
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4200
ubuntu-22-cmake-webgpu:
runs-on: ubuntu-22.04
ubuntu-24-cmake-webgpu:
runs-on: ubuntu-24.04
steps:
- name: Clone
@@ -456,16 +488,34 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-22-cmake-webgpu
key: ubuntu-24-cmake-webgpu
evict-old-files: 1d
- name: Vulkan SDK Dependencies
id: vulkan-depends
- name: Dependencies
id: depends
run: |
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
sudo add-apt-repository -y ppa:kisak/kisak-mesa
sudo apt-get update -y
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
- name: Get latest Vulkan SDK version
id: vulkan_sdk_version
run: |
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Use Vulkan SDK Cache
uses: actions/cache@v4
id: cache-sdk
with:
path: ./vulkan_sdk
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
- name: Setup Vulkan SDK
if: steps.cache-sdk.outputs.cache-hit != 'true'
uses: ./.github/actions/linux-setup-vulkan
with:
path: ./vulkan_sdk
version: ${{ env.VULKAN_SDK_VERSION }}
- name: Dawn Dependency
id: dawn-depends
@@ -1111,6 +1161,7 @@ jobs:
env:
# The ROCm version must correspond to the version used in the HIP SDK.
ROCM_VERSION: "6.4.2"
# Make sure this is in sync with build-cache.yml
HIPSDK_INSTALLER_VERSION: "25.Q3"
steps:
@@ -1125,33 +1176,18 @@ jobs:
7z x rocwmma.deb
7z x data.tar
- name: Cache ROCm Installation
id: cache-rocm
- name: Use ROCm Installation Cache
uses: actions/cache@v4
id: cache-rocm
with:
path: C:\Program Files\AMD\ROCm
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
- name: Install ROCm
- name: Setup ROCm
if: steps.cache-rocm.outputs.cache-hit != 'true'
id: depends
run: |
$ErrorActionPreference = "Stop"
write-host "Downloading AMD HIP SDK Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$completed = $proc.WaitForExit(600000)
if (-not $completed) {
Write-Error "ROCm installation timed out after 10 minutes. Killing the process"
$proc.Kill()
exit 1
}
if ($proc.ExitCode -ne 0) {
Write-Error "ROCm installation failed with exit code $($proc.ExitCode)"
exit 1
}
write-host "Completed AMD HIP SDK installation"
uses: ./.github/actions/windows-setup-rocm
with:
version: ${{ env.HIPSDK_INSTALLER_VERSION }}
- name: Verify ROCm
id: verify
@@ -1512,3 +1548,29 @@ jobs:
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-arm64-cpu-kleidiai:
runs-on: ubuntu-22.04-arm
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ggml-ci-arm64-cpu-kleidiai
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install -y build-essential libcurl4-openssl-dev
- name: Test
id: ggml-ci
run: |
GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
+1 -1
View File
@@ -75,7 +75,7 @@ jobs:
name: llama-bin-macos-arm64.zip
macOS-x64:
runs-on: macos-13
runs-on: macos-15-intel
steps:
- name: Clone
+2 -1
View File
@@ -2,7 +2,7 @@
# multiplie collaborators per item can be specified
/.devops/*.Dockerfile @ngxson
/.github/actions/ @slaren
/.github/actions/ @slaren @CISC
/.github/workflows/ @CISC
/.github/workflows/release.yml @slaren
/.github/workflows/winget.yml @slaren
@@ -70,6 +70,7 @@
/ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-threading.* @ggerganov @slaren
/ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml.c @ggerganov @slaren
/ggml/src/ggml.cpp @ggerganov @slaren
+32 -6
View File
@@ -22,6 +22,9 @@
# # with MUSA support
# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
# # with KLEIDIAI support
# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
if [ -z "$2" ]; then
echo "usage: $0 <output-dir> <mnt-dir>"
@@ -115,6 +118,34 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
fi
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
echo ">>===== Enabling KleidiAI support"
CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
CPU=""
for cpu in "${CANDIDATES[@]}"; do
if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then
CPU="$cpu"
break
fi
done
if [ -z "$CPU" ]; then
echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler."
exit 1
fi
echo ">>===== Using ARM baseline: ${CPU}"
CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \
-DGGML_NATIVE=OFF \
-DGGML_CPU_KLEIDIAI=ON \
-DGGML_CPU_AARCH64=ON \
-DGGML_CPU_ARM_ARCH=${CPU} \
-DBUILD_SHARED_LIBS=OFF"
fi
## helpers
# download a file if it does not exist or if it is outdated
@@ -512,12 +543,7 @@ function gg_run_rerank_tiny {
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.json
path_models="../models-mnt/rerank-tiny"
+189 -152
View File
@@ -1615,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) {
if (!rpc_reg) {
throw std::invalid_argument("failed to find RPC backend");
}
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
if (!ggml_backend_rpc_add_device_fn) {
throw std::invalid_argument("failed to find RPC device add function");
typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
if (!ggml_backend_rpc_add_server_fn) {
throw std::invalid_argument("failed to find RPC add server function");
}
for (const auto & server : rpc_servers) {
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
if (dev) {
ggml_backend_device_register(dev);
} else {
throw std::invalid_argument("failed to register RPC device");
}
auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
ggml_backend_register(reg);
}
}
@@ -1932,13 +1928,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
{"--ctx-checkpoints", "--swa-checkpoints"}, "N",
string_format("max number of context checkpoints to create per slot (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
params.n_ctx_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--cache-ram", "-cram"}, "N",
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
[](common_params & params, int value) {
params.cache_ram_mib = value;
}
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
@@ -2588,6 +2592,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.no_extra_bufts = true;
}
).set_env("LLAMA_ARG_NO_REPACK"));
add_opt(common_arg(
{"--no-host"},
"bypass host buffer allowing extra buffers to be used",
[](common_params & params) {
params.no_host = true;
}
).set_env("LLAMA_ARG_NO_HOST"));
add_opt(common_arg(
{"-ctk", "--cache-type-k"}, "TYPE",
string_format(
@@ -3347,7 +3358,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--chat-template-kwargs"}, "STRING",
string_format("sets additional params for the json template parser"),
[](common_params & params, const std::string & value) {
[](common_params & params, const std::string & value) {
auto parsed = json::parse(value);
for (const auto & item : parsed.items()) {
params.default_template_kwargs[item.key()] = item.value().dump();
@@ -3429,7 +3440,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--reasoning-format"}, "FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
"- deepseek: puts thoughts in `message.reasoning_content`\n"
"- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`\n"
"(default: auto)",
[](common_params & params, const std::string & value) {
params.reasoning_format = common_reasoning_format_from_name(value);
@@ -3558,21 +3570,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str());
}
));
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"'auto' enables colors when output is to a terminal",
[](common_params &, const std::string & value) {
if (is_truthy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
} else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg(
{"--log-colors"}, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"'auto' enables colors when output is to a terminal",
[](common_params &, const std::string & value) {
if (is_truthy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
} else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
}
}
).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg(
{"-v", "--verbose", "--log-verbose"},
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
@@ -3838,7 +3852,87 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_TTS}));
// model-specific
add_opt(common_arg(
{"--diffusion-steps"}, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
[](common_params & params, int value) { params.diffusion.steps = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-visual"},
string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"),
[](common_params & params) { params.diffusion.visual_mode = true; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-eps"}, "F",
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-algorithm"}, "N",
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
[](common_params & params, int value) { params.diffusion.algorithm = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-alg-temp"}, "F",
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-block-length"}, "N",
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
[](common_params & params, int value) { params.diffusion.block_length = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-cfg-scale"}, "F",
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-add-gumbel-noise"}, "F",
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "-lr", "--learning-rate" }, "ALPHA",
string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
string_format("(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
(double) params.lr.lr_min),
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-decay-epochs", "--learning-rate-decay-epochs"}, "ALPHA",
string_format("(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)", (double) params.lr.decay_epochs),
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-wd", "--weight-decay"}, "WD",
string_format("adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).", (double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-val-split", "--val-split"}, "FRACTION",
string_format("fraction of data to use as validation set for training (default: %.2g).", (double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-epochs", "--epochs"}, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-opt", "--optimizer"}, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = common_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
}
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
// presets
add_opt(common_arg(
{"--tts-oute-default"},
string_format("use default OuteTTS models (note: can download weights from the internet)"),
@@ -3851,42 +3945,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_TTS}));
add_opt(common_arg(
{"--embd-bge-small-en-default"},
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
{"--embd-gemma-default"},
string_format("use default EmbeddingGemma model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--embd-e5-small-en-default"},
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
params.model.hf_file = "e5-small-v2-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--embd-gte-small-default"},
string_format("use default gte-small model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
params.model.hf_file = "gte-small-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.model.hf_repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF";
params.model.hf_file = "embeddinggemma-300M-qat-Q4_0.gguf";
params.port = 8011;
params.n_ubatch = 2048;
params.n_batch = 2048;
params.n_parallel = 32;
params.n_ctx = 2048*params.n_parallel;
params.verbose_prompt = true;
params.embedding = true;
}
@@ -3981,96 +4049,65 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-steps" }, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
[](common_params & params, int value) { params.diffusion.steps = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-visual" },
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
params.diffusion.visual_mode ? "true" : "false"),
[](common_params & params) { params.diffusion.visual_mode = true; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
{"--gpt-oss-20b-default"},
string_format("use gpt-oss-20b (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gpt-oss-20b-GGUF";
params.model.hf_file = "gpt-oss-20b-mxfp4.gguf";
params.port = 8013;
params.n_ubatch = 2048;
params.n_batch = 32768;
params.n_parallel = 2;
params.n_ctx = 131072*params.n_parallel;
params.sampling.temp = 1.0f;
params.sampling.top_p = 1.0f;
params.sampling.top_k = 0;
params.sampling.min_p = 0.01f;
params.use_jinja = true;
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-eps" }, "F",
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-algorithm" }, "N",
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
params.diffusion.algorithm),
[](common_params & params, int value) { params.diffusion.algorithm = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-alg-temp" }, "F",
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
{"--gpt-oss-120b-default"},
string_format("use gpt-oss-120b (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gpt-oss-120b-GGUF";
params.port = 8013;
params.n_ubatch = 2048;
params.n_batch = 32768;
params.n_parallel = 2;
params.n_ctx = 131072*params.n_parallel;
params.sampling.temp = 1.0f;
params.sampling.top_p = 1.0f;
params.sampling.top_k = 0;
params.sampling.min_p = 0.01f;
params.use_jinja = true;
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-block-length" }, "N",
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
[](common_params & params, int value) { params.diffusion.block_length = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-cfg-scale" }, "F",
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-add-gumbel-noise" }, "F",
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
{"--vision-gemma-4b-default"},
string_format("use Gemma 3 4B QAT (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gemma-3-4b-it-qat-GGUF";
params.port = 8014;
params.n_ctx = 0;
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
string_format(
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
(double) params.lr.lr0),
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
string_format(
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
(double) params.lr.lr_min),
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
string_format(
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
(double) params.lr.decay_epochs),
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{ "-wd", "--weight-decay" }, "WD",
string_format(
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
(double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
string_format("fraction of data to use as validation set for training (default: %.2g).",
(double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = common_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
})
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
{"--vision-gemma-12b-default"},
string_format("use Gemma 3 12B QAT (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gemma-3-12b-it-qat-GGUF";
params.port = 8014;
params.n_ctx = 0;
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}
+127 -15
View File
@@ -3,9 +3,12 @@
#include "log.h"
#include "regex-partial.h"
#include <algorithm>
#include <cctype>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
using json = nlohmann::ordered_json;
@@ -166,6 +169,27 @@ void common_chat_msg_parser::consume_literal(const std::string & literal) {
}
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
std::string pending_reasoning_prefix;
if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) {
return false;
}
auto set_reasoning_prefix = [&](size_t prefix_pos) {
if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) {
return;
}
if (prefix_pos + start_think.size() > input_.size()) {
pending_reasoning_prefix.clear();
return;
}
// Capture the exact literal that opened the reasoning section so we can
// surface it back to callers. This ensures formats that force the
// reasoning tag open (e.g. DeepSeek R1) retain their original prefix
// instead of dropping it during parsing.
pending_reasoning_prefix = input_.substr(prefix_pos, start_think.size());
};
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
auto stripped_reasoning = string_strip(reasoning);
if (stripped_reasoning.empty()) {
@@ -178,28 +202,116 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
}
} else {
if (!pending_reasoning_prefix.empty()) {
add_reasoning_content(pending_reasoning_prefix);
pending_reasoning_prefix.clear();
}
add_reasoning_content(stripped_reasoning);
}
};
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
if (auto res = try_find_literal(end_think)) {
handle_reasoning(res->prelude, /* closed */ true);
consume_spaces();
return true;
}
auto rest = consume_rest();
const size_t saved_pos = pos_;
const size_t saved_content_size = result_.content.size();
const size_t saved_reasoning_size = result_.reasoning_content.size();
auto restore_state = [&]() {
move_to(saved_pos);
result_.content.resize(saved_content_size);
result_.reasoning_content.resize(saved_reasoning_size);
};
// Allow leading whitespace to be preserved as content when reasoning is present at the start
size_t cursor = pos_;
size_t whitespace_end = cursor;
while (whitespace_end < input_.size() && std::isspace(static_cast<unsigned char>(input_[whitespace_end]))) {
++whitespace_end;
}
if (whitespace_end >= input_.size()) {
restore_state();
if (syntax_.thinking_forced_open) {
auto rest = input_.substr(saved_pos);
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
// if (!syntax_.thinking_forced_open) {
// throw common_chat_msg_partial_exception(end_think);
// }
move_to(input_.size());
return true;
}
return false;
}
cursor = whitespace_end;
const size_t remaining = input_.size() - cursor;
const size_t start_prefix = std::min(start_think.size(), remaining);
const bool has_start_tag = input_.compare(cursor, start_prefix, start_think, 0, start_prefix) == 0;
if (has_start_tag && start_prefix < start_think.size()) {
move_to(input_.size());
return true;
}
if (has_start_tag) {
if (whitespace_end > pos_) {
add_content(input_.substr(pos_, whitespace_end - pos_));
}
set_reasoning_prefix(cursor);
cursor += start_think.size();
} else if (syntax_.thinking_forced_open) {
cursor = whitespace_end;
} else {
restore_state();
return false;
}
while (true) {
if (cursor >= input_.size()) {
move_to(input_.size());
return true;
}
size_t end_pos = input_.find(end_think, cursor);
if (end_pos == std::string::npos) {
std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor);
size_t partial_off = string_find_partial_stop(remaining_view, end_think);
size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off;
if (reasoning_end > cursor) {
handle_reasoning(input_.substr(cursor, reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial());
}
move_to(input_.size());
return true;
}
if (end_pos > cursor) {
handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true);
} else {
handle_reasoning("", /* closed */ true);
}
cursor = end_pos + end_think.size();
while (cursor < input_.size() && std::isspace(static_cast<unsigned char>(input_[cursor]))) {
++cursor;
}
const size_t next_remaining = input_.size() - cursor;
if (next_remaining == 0) {
move_to(cursor);
return true;
}
const size_t next_prefix = std::min(start_think.size(), next_remaining);
if (input_.compare(cursor, next_prefix, start_think, 0, next_prefix) == 0) {
if (next_prefix < start_think.size()) {
move_to(input_.size());
return true;
}
set_reasoning_prefix(cursor);
cursor += start_think.size();
continue;
}
move_to(cursor);
return true;
}
return false;
}
std::string common_chat_msg_parser::consume_rest() {
@@ -320,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
/* .is_partial = */ false,
};
}
@@ -332,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {
+82
View File
@@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
@@ -984,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
return data;
}
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
data.preserved_tokens = {
"[THINK]",
"[/THINK]",
};
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function.at("name")},
}},
{"arguments", function.at("parameters")},
{"id", {
{"type", "string"},
{"pattern", "^[a-zA-Z0-9]{9}$"},
}},
}},
{"required", json::array({"name", "arguments", "id"})},
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
});
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
data.preserved_tokens.push_back("[TOOL_CALLS]");
} else {
data.grammar_lazy = false;
if (!inputs.json_schema.is_null()) {
if (!inputs.grammar.empty()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
data.grammar = json_schema_to_grammar(inputs.json_schema);
} else {
data.grammar = inputs.grammar;
}
}
return data;
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
@@ -994,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("[THINK]", "[/THINK]");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -1336,6 +1408,8 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
@@ -2702,6 +2776,10 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
}
if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
return common_chat_params_init_magistral(tmpl, params);
}
// Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
@@ -2786,6 +2864,7 @@ common_chat_params common_chat_templates_apply(
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
@@ -2802,6 +2881,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_MAGISTRAL:
common_chat_parse_magistral(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
+4 -3
View File
@@ -33,8 +33,8 @@ struct common_chat_msg_content_part {
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_chat_msg_content_part> content_parts = {};
std::vector<common_chat_tool_call> tool_calls = {};
std::vector<common_chat_msg_content_part> content_parts;
std::vector<common_chat_tool_call> tool_calls;
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
@@ -44,7 +44,7 @@ struct common_chat_msg {
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
}
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
for (auto i = 0u; i < tool_calls.size(); i++) {
if (ids_cache.size() <= i) {
auto id = tool_calls[i].id;
@@ -101,6 +101,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_MAGISTRAL,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+1
View File
@@ -1133,6 +1133,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.use_extra_bufts = !params.no_extra_bufts;
mparams.no_host = params.no_host;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
+5 -3
View File
@@ -378,7 +378,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool no_perf = false; // disable performance metrics
bool ctx_shift = false; // context shift on infinite text generation
bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
bool kv_unified = false; // enable unified KV cache
@@ -392,6 +392,7 @@ struct common_params {
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
bool no_host = false; // bypass host buffer allowing extra buffers to be used
bool single_turn = false; // single turn chat conversation
@@ -424,7 +425,8 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
@@ -432,7 +434,7 @@ struct common_params {
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
+51
View File
@@ -5,6 +5,7 @@
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
@@ -168,6 +169,47 @@ bool common_json_parse(
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
@@ -186,6 +228,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
@@ -205,6 +250,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
@@ -230,6 +278,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
+150 -14
View File
@@ -93,13 +93,15 @@ class ModelBase:
# Mistral format specifics
is_mistral_format: bool = False
disable_mistral_community_chat_template: bool = False
sentence_transformers_dense_modules: bool = False
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False):
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -114,6 +116,7 @@ class ModelBase:
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
if remote_hf_model_id is not None:
self.is_safetensors = True
@@ -891,6 +894,9 @@ class TextModel(ModelBase):
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
res = "granite-docling"
if res is None:
logger.warning("\n")
@@ -1325,6 +1331,7 @@ class MmprojModel(ModelBase):
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
self.preprocessor_config = {}
if not self.is_mistral_format:
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
@@ -1347,7 +1354,8 @@ class MmprojModel(ModelBase):
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
# vision config
self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"]))
self.image_size = self.find_vparam(["image_size"])
self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
@@ -2378,6 +2386,10 @@ class SmolVLMModel(MmprojModel):
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
self.gguf_writer.add_vision_use_gelu(True)
# Add the preprocessor longest edge size
preproc_image_size = self.preprocessor_config.get("size", {}).get("longest_edge", self.image_size)
self.gguf_writer.add_vision_preproc_image_size(preproc_image_size)
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".embeddings." in name:
return gguf.GGMLQuantizationType.F32
@@ -5260,6 +5272,53 @@ class Gemma3Model(TextModel):
@ModelBase.register("Gemma3TextModel")
class EmbeddingGemma(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
module_paths = []
dense_features_dims = {}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.sentence_transformers_dense_modules:
# read modules.json to determine if model has Dense layers
modules_file = self.dir_model / "modules.json"
if modules_file.is_file():
with open(modules_file, encoding="utf-8") as modules_json_file:
mods = json.load(modules_json_file)
for mod in mods:
if mod["type"] == "sentence_transformers.models.Dense":
mod_path = mod["path"]
# check if model.safetensors file for Dense layer exists
model_tensors_file = self.dir_model / mod_path / "model.safetensors"
if model_tensors_file.is_file():
self.module_paths.append(mod_path)
# read config.json of the Dense layer to get in/out features
mod_conf_file = self.dir_model / mod_path / "config.json"
if mod_conf_file.is_file():
with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
mod_conf = json.load(mod_conf_json_file)
# hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
prefix = self._get_dense_prefix(mod_path)
if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
from safetensors.torch import load_file
module_paths = list(self.module_paths)
for i, module_path in enumerate(module_paths):
tensors_file = self.dir_model / module_path / "model.safetensors"
local_tensors = load_file(tensors_file)
tensor_name = self._get_dense_prefix(module_path)
for name, local_tensor in local_tensors.items():
if not name.endswith(".weight"):
continue
orig_name = name.replace("linear", tensor_name)
name = self.map_tensor_name(orig_name)
yield name, local_tensor.clone()
@staticmethod
def _get_dense_prefix(module_path) -> str:
"""Get the tensor name prefix for the Dense layer from module path."""
tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
return tensor_name
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -5276,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model):
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
f"instead of {self.hparams['sliding_window']}")
self.gguf_writer.add_sliding_window(orig_sliding_window)
if self.sentence_transformers_dense_modules:
for dense, dims in self.dense_features_dims.items():
logger.info(f"Setting dense layer {dense} in/out features to {dims}")
self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
self._try_set_pooling_type()
@@ -5903,20 +5966,12 @@ class Mamba2Model(TextModel):
class JambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.JAMBA
def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused
return "gpt-2"
def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()
self._set_vocab_llama_hf()
self.gguf_writer.add_add_space_prefix(False)
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
@@ -8827,6 +8882,75 @@ class LFM2Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.LFM2MOE
def set_gguf_parameters(self):
# set num_key_value_heads only for attention layers
self.hparams["num_key_value_heads"] = [
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
for layer_type in self.hparams["layer_types"]
]
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
# cache for experts weights for merging
_experts_cache: dict[int, dict[str, Tensor]] = {}
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# conv op requires 2d tensor
if 'conv.conv' in name:
data_torch = data_torch.squeeze(1)
if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
expert_cache[name] = data_torch
expert_weights = ["w1", "w2", "w3"]
# not enough expert weights to merge
if len(expert_cache) < n_experts * len(expert_weights):
return []
tensors: list[tuple[str, Tensor]] = []
for w_name in expert_weights:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
datas.append(expert_cache[ename])
del expert_cache[ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
del self._experts_cache[bid]
return tensors
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
assert not self._experts_cache
@ModelBase.register("Lfm2VlForConditionalGeneration")
class LFM2VLModel(MmprojModel):
def __init__(self, *args, **kwargs):
@@ -9257,6 +9381,13 @@ def parse_args() -> argparse.Namespace:
)
)
parser.add_argument(
"--sentence-transformers-dense-modules", action="store_true",
help=("Whether to include sentence-transformers dense modules."
"It can be used for sentence-transformers models, like google/embeddinggemma-300m"
"Default these modules are not included.")
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
@@ -9319,9 +9450,13 @@ def main() -> None:
if args.remote:
hf_repo_id = args.model
from huggingface_hub import snapshot_download
allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
if args.sentence_transformers_dense_modules:
# include sentence-transformers dense modules safetensors files
allowed_patterns.append("*.safetensors")
local_dir = snapshot_download(
repo_id=hf_repo_id,
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
allow_patterns=allowed_patterns)
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
else:
@@ -9389,7 +9524,8 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
)
if args.vocab_only:
+1
View File
@@ -140,6 +140,7 @@ models = [
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
+10 -8
View File
@@ -31,7 +31,7 @@ Legend:
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -51,7 +51,7 @@ Legend:
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
@@ -65,11 +65,11 @@ Legend:
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@@ -92,9 +92,9 @@ Legend:
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | | ✅ | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
@@ -102,9 +102,11 @@ Legend:
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+12095 -4249
View File
File diff suppressed because it is too large Load Diff
+21 -2
View File
@@ -116,20 +116,39 @@ embedding-convert-model:
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/embedding/convert-model.sh
embedding-convert-model-st:
$(call validate_embedding_model_path,embedding-convert-model-st)
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/embedding/convert-model.sh -st
embedding-run-original-model:
$(call validate_embedding_model_path,embedding-run-original-model)
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
./scripts/embedding/run-original-model.py \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
$(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
embedding-run-original-model-st: embedding-run-original-model
embedding-run-converted-model:
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
$(if $(USE_POOLING),--pooling)
embedding-run-converted-model-st: USE_POOLING=1
embedding-run-converted-model-st: embedding-run-converted-model
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
embedding-inspect-original-model:
$(call validate_embedding_model_path,embedding-inspect-original-model)
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
+24
View File
@@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary
file containing logits which will be used for comparison with the converted
model, and the other is a text file which allows for manual visual inspection.
#### Using SentenceTransformer with numbered layers
For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
```console
# Run original model with SentenceTransformer (applies all numbered layers)
(venv) $ make embedding-run-original-model-st
# Run converted model with pooling enabled
(venv) $ make embedding-run-converted-model-st
```
This will use the SentenceTransformer library to load and run the model, which
automatically applies all the numbered layers in the correct order. This is
particularly useful when comparing with models that should include these
additional transformation layers beyond just the base model output.
### Model conversion
After updates have been made to [gguf-py](../../gguf-py) to add support for the
new model the model can be converted to GGUF format using the following command:
@@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits:
(venv) $ make embedding-verify-logits
```
For models with SentenceTransformer layers, use the `-st` verification target:
```console
(venv) $ make embedding-verify-logits-st
```
This convenience target automatically runs both the original model with SentenceTransformer
and the converted model with pooling enabled, then compares the results.
### llama-server verification
To verify that the converted model works with llama-server, the following
command can be used:
+53 -25
View File
@@ -1,4 +1,7 @@
#include "llama.h"
#include "common.h"
#include <cstdio>
#include <cstring>
#include <string>
@@ -8,7 +11,10 @@
static void print_usage(int, char ** argv) {
printf("\nexample usage:\n");
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
printf("\n");
printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
printf("\n");
}
@@ -17,6 +23,8 @@ int main(int argc, char ** argv) {
std::string prompt = "Hello, my name is";
int ngl = 0;
bool embedding_mode = false;
bool pooling_enabled = false;
int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
{
int i = 1;
@@ -41,9 +49,13 @@ int main(int argc, char ** argv) {
return 1;
}
} else if (strcmp(argv[i], "-embd-mode") == 0) {
embedding_mode = true;
} else if (strcmp(argv[i], "-pooling") == 0) {
pooling_enabled = true;
} else if (strcmp(argv[i], "-embd-norm") == 0) {
if (i + 1 < argc) {
try {
embedding_mode = true;
embd_norm = std::stoi(argv[++i]);
} catch (...) {
print_usage(argc, argv);
return 1;
@@ -112,7 +124,7 @@ int main(int argc, char ** argv) {
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}
@@ -143,17 +155,27 @@ int main(int argc, char ** argv) {
return 1;
}
float * logits;
int n_logits;
float * data_ptr;
int data_size;
const char * type;
std::vector<float> embd_out;
if (embedding_mode) {
logits = llama_get_embeddings(ctx);
n_logits = llama_model_n_embd(model) * batch.n_tokens;
const int n_embd = llama_model_n_embd(model);
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
const int n_embeddings = n_embd * n_embd_count;
float * embeddings;
type = "-embeddings";
const int n_embd = llama_model_n_embd(model);
const int n_embd_count = batch.n_tokens;
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
embeddings = llama_get_embeddings_seq(ctx, 0);
embd_out.resize(n_embeddings);
printf("Normalizing embeddings using norm: %d\n", embd_norm);
common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
embeddings = embd_out.data();
} else {
embeddings = llama_get_embeddings(ctx);
}
printf("Embedding dimension: %d\n", n_embd);
printf("\n");
@@ -164,7 +186,7 @@ int main(int argc, char ** argv) {
// Print first 3 values
for (int i = 0; i < 3 && i < n_embd; i++) {
printf("%9.6f ", logits[j * n_embd + i]);
printf("%9.6f ", embeddings[j * n_embd + i]);
}
printf(" ... ");
@@ -172,7 +194,7 @@ int main(int argc, char ** argv) {
// Print last 3 values
for (int i = n_embd - 3; i < n_embd; i++) {
if (i >= 0) {
printf("%9.6f ", logits[j * n_embd + i]);
printf("%9.6f ", embeddings[j * n_embd + i]);
}
}
@@ -180,27 +202,33 @@ int main(int argc, char ** argv) {
}
printf("\n");
printf("Embeddings size: %d\n", n_logits);
printf("Embeddings size: %d\n", n_embeddings);
data_ptr = embeddings;
data_size = n_embeddings;
} else {
logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
n_logits = llama_vocab_n_tokens(vocab);
float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
type = "";
printf("Vocab size: %d\n", n_logits);
data_ptr = logits;
data_size = n_logits;
}
std::filesystem::create_directory("data");
// Save logits to binary file
// Save data to binary file
char bin_filename[512];
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
printf("Saving logits to %s\n", bin_filename);
printf("Saving data to %s\n", bin_filename);
FILE * f = fopen(bin_filename, "wb");
if (f == NULL) {
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
return 1;
}
fwrite(logits, sizeof(float), n_logits, f);
fwrite(data_ptr, sizeof(float), data_size, f);
fclose(f);
// Also save as text for debugging
@@ -211,27 +239,27 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
return 1;
}
for (int i = 0; i < n_logits; i++) {
fprintf(f, "%d: %.6f\n", i, logits[i]);
for (int i = 0; i < data_size; i++) {
fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
}
fclose(f);
if (!embedding_mode) {
printf("First 10 logits: ");
for (int i = 0; i < 10 && i < n_logits; i++) {
printf("%.6f ", logits[i]);
for (int i = 0; i < 10 && i < data_size; i++) {
printf("%.6f ", data_ptr[i]);
}
printf("\n");
printf("Last 10 logits: ");
for (int i = n_logits - 10; i < n_logits; i++) {
if (i >= 0) printf("%.6f ", logits[i]);
for (int i = data_size - 10; i < data_size; i++) {
if (i >= 0) printf("%.6f ", data_ptr[i]);
}
printf("\n\n");
}
printf("Logits saved to %s\n", bin_filename);
printf("Logits saved to %s\n", txt_filename);
printf("Data saved to %s\n", bin_filename);
printf("Data saved to %s\n", txt_filename);
llama_free(ctx);
llama_model_free(model);
@@ -4,3 +4,4 @@ torchvision
transformers
huggingface-hub
accelerate
sentence-transformers
@@ -2,6 +2,21 @@
set -e
# Parse command line arguments
SENTENCE_TRANSFORMERS=""
while [[ $# -gt 0 ]]; do
case $1 in
-st|--sentence-transformers)
SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
shift
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
done
MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
TYPE="${OUTTYPE:-f16}"
@@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}"
python ../../convert_hf_to_gguf.py --verbose \
${EMBEDDING_MODEL_PATH} \
--outfile ${CONVERTED_MODEL} \
--outtype ${TYPE}
--outtype ${TYPE} \
${SENTENCE_TRANSFORMERS}
echo ""
echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:"
@@ -5,6 +5,7 @@ set -e
# Parse command line arguments
CONVERTED_MODEL=""
PROMPTS_FILE=""
USE_POOLING=""
while [[ $# -gt 0 ]]; do
case $1 in
@@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do
PROMPTS_FILE="$2"
shift 2
;;
--pooling)
USE_POOLING="1"
shift
;;
*)
if [ -z "$CONVERTED_MODEL" ]; then
CONVERTED_MODEL="$1"
@@ -47,4 +52,8 @@ echo $CONVERTED_MODEL
cmake --build ../../build --target llama-logits -j8
# TODO: update logits.cpp to accept a --file/-f option for the prompt
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
if [ -n "$USE_POOLING" ]; then
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
else
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
fi
@@ -14,6 +14,8 @@ unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
parser.add_argument('--use-sentence-transformers', action='store_true',
help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
args = parser.parse_args()
def read_prompt_from_file(file_path):
@@ -31,41 +33,52 @@ model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path)
if model_path is None:
parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Determine if we should use SentenceTransformer
use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
config = AutoConfig.from_pretrained(model_path)
# This can be used to override the sliding window size for manual testing. This
# can be useful to verify the sliding window attention mask in the original model
# and compare it with the converted .gguf model.
if hasattr(config, 'sliding_window'):
original_sliding_window = config.sliding_window
#original_sliding_window = 6
print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
print(f"Using unreleased model: {unreleased_model_name}")
if unreleased_model_name:
model_name_lower = unreleased_model_name.lower()
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
class_name = f"{unreleased_model_name}Model"
print(f"Importing unreleased model module: {unreleased_module_path}")
try:
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
model = model_class.from_pretrained(model_path, config=config)
except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}")
exit(1)
if use_sentence_transformers:
from sentence_transformers import SentenceTransformer
print("Using SentenceTransformer to apply all numbered layers")
model = SentenceTransformer(model_path)
tokenizer = model.tokenizer
config = model[0].auto_model.config # type: ignore
else:
model = AutoModel.from_pretrained(model_path, config=config)
print(f"Model class: {type(model)}")
print(f"Model file: {type(model).__module__}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)
# This can be used to override the sliding window size for manual testing. This
# can be useful to verify the sliding window attention mask in the original model
# and compare it with the converted .gguf model.
if hasattr(config, 'sliding_window'):
original_sliding_window = config.sliding_window
#original_sliding_window = 6
print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
print(f"Using unreleased model: {unreleased_model_name}")
if unreleased_model_name:
model_name_lower = unreleased_model_name.lower()
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
class_name = f"{unreleased_model_name}Model"
print(f"Importing unreleased model module: {unreleased_module_path}")
try:
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
model = model_class.from_pretrained(model_path, config=config)
except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}")
exit(1)
else:
model = AutoModel.from_pretrained(model_path, config=config)
print(f"Model class: {type(model)}")
print(f"Model file: {type(model).__module__}")
# Verify the model is using the correct sliding window
if hasattr(model.config, 'sliding_window'):
print(f"Model's sliding_window: {model.config.sliding_window}")
else:
print("Model config does not have sliding_window attribute")
if not use_sentence_transformers:
if hasattr(model.config, 'sliding_window'): # type: ignore
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
else:
print("Model config does not have sliding_window attribute")
model_name = os.path.basename(model_path)
@@ -75,34 +88,56 @@ if args.prompts_file:
else:
texts = ["Hello world today"]
encoded = tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
)
tokens = encoded['input_ids'][0]
token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'")
with torch.no_grad():
outputs = model(**encoded)
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
if use_sentence_transformers:
embeddings = model.encode(texts, convert_to_numpy=True)
all_embeddings = embeddings # Shape: [batch_size, hidden_size]
# Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior)
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
encoded = tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
)
tokens = encoded['input_ids'][0]
token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'")
print(f"Hidden states shape: {hidden_states.shape}")
print(f"All embeddings shape: {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1]}")
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
else:
# Standard approach: use base model output only
encoded = tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
)
# Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE
n_embd = all_embeddings.shape[1]
n_embd_count = all_embeddings.shape[0]
tokens = encoded['input_ids'][0]
token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'")
print() # Empty line to match C++ output
outputs = model(**encoded)
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
print(f"Hidden states shape: {hidden_states.shape}")
print(f"All embeddings shape: {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1]}")
if len(all_embeddings.shape) == 1:
n_embd = all_embeddings.shape[0] # type: ignore
n_embd_count = 1
all_embeddings = all_embeddings.reshape(1, -1)
else:
n_embd = all_embeddings.shape[1] # type: ignore
n_embd_count = all_embeddings.shape[0] # type: ignore
print()
for j in range(n_embd_count):
embedding = all_embeddings[j]
@@ -120,29 +155,23 @@ with torch.no_grad():
print() # New line
print() # Final empty line to match C++ output
print()
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
# Save all embeddings flattened (matching what embedding.cpp would save if it did)
flattened_embeddings = all_embeddings.flatten()
flattened_embeddings.astype(np.float32).tofile(bin_filename)
with open(txt_filename, "w") as f:
f.write(f"# Model class: {model_name}\n")
f.write(f"# Tokens: {token_strings}\n")
f.write(f"# Shape: {all_embeddings.shape}\n")
f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n")
idx = 0
for j in range(n_embd_count):
f.write(f"# Token {j} ({token_strings[j]}):\n")
for i, value in enumerate(all_embeddings[j]):
f.write(f"{j}_{i}: {value:.6f}\n")
f.write("\n")
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)")
for value in all_embeddings[j]:
f.write(f"{idx}: {value:.6f}\n")
idx += 1
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
print("")
print(f"Saved bin embeddings to: {bin_filename}")
print(f"Saved txt embeddings to: {txt_filename}")
@@ -35,7 +35,11 @@ def cosine_similarity(a, b=None):
def load_embeddings_from_file(filename, n_tokens, n_embd):
embeddings = np.fromfile(filename, dtype=np.float32)
return embeddings.reshape(n_tokens, n_embd)
# Check if this is pooled (single embedding) or per-token embeddings
if len(embeddings) == n_embd:
return embeddings.reshape(1, n_embd)
else:
return embeddings.reshape(n_tokens, n_embd)
def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
np.set_printoptions(suppress=True, precision=6)
@@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
n_tokens = len(tokens)
is_pooled = python_emb.shape[0] == 1
# 1. Direct embedding comparison
print(f"\n1. Raw Embedding Magnitude Comparison:")
# Check if the distance of each token embedding from the origin and compare
# if the vectors are on the same "sphere". This does not tell us about
# direction (meaning of the token embedding), just magnitude.
for i in range(n_tokens):
py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
if is_pooled:
print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
# 1. Direct embedding comparison for pooled embeddings
print(f"\n1. Raw Embedding Magnitude Comparison:")
py_mag = np.linalg.norm(python_emb[0])
cpp_mag = np.linalg.norm(cpp_emb[0])
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
# 2. Cosine similarity between tokens within each model
# Here we check the direction of token embeddings to see if the have the
# same meaning (similarity). This is done by calculating cosine similarity
# of a pair of token embeddings within each model.
print(f"\n2. Within-Model Token Similarities:")
print(" Python model:")
for i in range(n_tokens):
for j in range(i+1, n_tokens):
sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
print(f" {tokens[i]}{tokens[j]}: {sim:.4f}")
# 2. Cross-model similarity for pooled embeddings
print(f"\n2. Cross-Model Pooled Embedding Similarity:")
sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
print(f" Cosine similarity: {sim:.6f}")
print(" llama.cpp model:")
for i in range(n_tokens):
for j in range(i+1, n_tokens):
sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
print(f" {tokens[i]}{tokens[j]}: {sim:.4f}")
return {
'cross_model_similarities': [sim],
'similarity_matrix_diff': np.array([[0.0]]),
'max_diff': 0.0,
'mean_diff': 0.0,
'rms_diff': 0.0
}
else:
# Original per-token comparison logic
# 1. Direct embedding comparison
print(f"\n1. Raw Embedding Magnitude Comparison:")
# Check if the distance of each token embedding from the origin and compare
# if the vectors are on the same "sphere". This does not tell us about
# direction (meaning of the token embedding), just magnitude.
for i in range(n_tokens):
py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
# 3. Cross-model similarity (same token position)
print(f"\n3. Cross-Model Same-Token Similarities:")
for i in range(n_tokens):
sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
print(f" Token {i} ({tokens[i]}): {sim:.4f}")
# 2. Cosine similarity between tokens within each model
# Here we check the direction of token embeddings to see if the have the
# same meaning (similarity). This is done by calculating cosine similarity
# of a pair of token embeddings within each model.
print(f"\n2. Within-Model Token Similarities:")
print(" Python model:")
for i in range(n_tokens):
for j in range(i+1, n_tokens):
sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
print(f" {tokens[i]}{tokens[j]}: {sim:.4f}")
# 4. Similarity matrix comparison
print(f"\n4. Similarity Matrix Differences:")
py_sim_matrix = cosine_similarity(python_emb)
cpp_sim_matrix = cosine_similarity(cpp_emb)
diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
print(" llama.cpp model:")
for i in range(n_tokens):
for j in range(i+1, n_tokens):
sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
print(f" {tokens[i]}{tokens[j]}: {sim:.4f}")
print(f" Max difference: {np.max(diff_matrix):.4f}")
print(f" Mean difference: {np.mean(diff_matrix):.4f}")
print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
# 3. Cross-model similarity (same token position)
print(f"\n3. Cross-Model Same-Token Similarities:")
for i in range(n_tokens):
sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
print(f" Token {i} ({tokens[i]}): {sim:.4f}")
return {
'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
'similarity_matrix_diff': diff_matrix,
'max_diff': np.max(diff_matrix),
'mean_diff': np.mean(diff_matrix),
'rms_diff': np.sqrt(np.mean(diff_matrix**2))
}
# 4. Similarity matrix comparison
print(f"\n4. Similarity Matrix Differences:")
py_sim_matrix = cosine_similarity(python_emb)
cpp_sim_matrix = cosine_similarity(cpp_emb)
diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
print(f" Max difference: {np.max(diff_matrix):.4f}")
print(f" Mean difference: {np.mean(diff_matrix):.4f}")
print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
return {
'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
'similarity_matrix_diff': diff_matrix,
'max_diff': np.max(diff_matrix),
'mean_diff': np.mean(diff_matrix),
'rms_diff': np.sqrt(np.mean(diff_matrix**2))
}
def read_prompt_from_file(file_path):
try:
+3
View File
@@ -222,6 +222,9 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
option(GGML_ZDNN "ggml: use zDNN" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
+2
View File
@@ -215,6 +215,8 @@ extern "C" {
// Backend registry
//
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Backend (reg) enumeration
+8 -9
View File
@@ -7,26 +7,25 @@
extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 2
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16
// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
const char * cache_dir,
size_t free_mem, size_t total_mem);
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
#ifdef __cplusplus
}
+7
View File
@@ -1630,6 +1630,13 @@ extern "C" {
float scale,
float max_bias);
GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
float max_bias);
GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
+3
View File
@@ -145,6 +145,9 @@ endif()
# which was introduced in POSIX.1-2008, forcing us to go higher
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
add_compile_definitions(_XOPEN_SOURCE=700)
elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
# Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
# in order to define _SC_PHYS_PAGES.
else()
add_compile_definitions(_XOPEN_SOURCE=600)
endif()
+16 -14
View File
@@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
free(alloc);
}
static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
size_t max_size = 0;
for (int i = 0; i < alloc->n_chunks; i++) {
max_size += alloc->chunks[i]->max_size;
}
return max_size;
static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) {
return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0;
}
@@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) {
free(buf);
}
static int ggml_vbuffer_n_chunks(struct vbuffer * buf) {
int n = 0;
while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++;
return n;
static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {
return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;
}
static size_t ggml_vbuffer_size(struct vbuffer * buf) {
@@ -885,12 +879,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
}
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
// even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
if (new_size > cur_size || galloc->buffers[i] == NULL) {
bool realloc = galloc->buffers[i] == NULL;
size_t new_size = 0;
for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;
size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);
new_size += new_chunk_size;
if (new_chunk_size > cur_chunk_size) {
realloc = true;
}
}
if (realloc) {
#ifndef NDEBUG
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
-3
View File
@@ -209,9 +209,6 @@ extern "C" {
void * context;
};
// Internal backend registry API
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
// Add backend dynamic loading support to the backend
// Initialize the backend
+100 -107
View File
@@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
unary_op(ctx, acl_src0, acl_dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
if(src1)
ggml_cann_release_resources(ctx, acl_src1);
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
}
/**
@@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
}
/**
* @brief Get or expand a cached float32 tensor filled with a scalar value.
* @brief Get or expand a cached tensor filled with a scalar value.
*
* This function manages cached device memory for float32 tensors. If the current
* This function manages cached device memory for tensors. If the current
* cache size is insufficient for the requested tensor shape, the old memory will
* be released and new memory will be allocated. The allocated buffer is then
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
* value using CANN operations. Finally, an aclTensor object is created from the
* cached memory and returned.
* be released and new memory will be allocated. The allocated buffer is
* initialized with the given scalar value using CANN operations.
* Finally, an aclTensor object is created from the cached memory and returned.
*
* @param ctx The CANN backend context that manages device memory.
* @param buffer A pointer to the cached device buffer (will be allocated
@@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
* updated when the cache is expanded.
* @param ne The tensor shape array (number of elements in each dimension).
* @param nb The stride size for each dimension.
* @param dtype Data type of cached tensor.
* @param dims The number of tensor dimensions.
* @param value The scalar value used to fill the tensor (supports zero
* initialization via memset or arbitrary values via fill_scalar).
* @return An aclTensor pointer created from the cached buffer.
*/
static aclTensor* get_f32_cache_acl_tensor(
static aclTensor* get_cache_acl_tensor(
ggml_backend_cann_context& ctx,
void** buffer,
int64_t &cache_element,
int64_t* ne,
size_t* nb,
ggml_type dtype,
int64_t dims,
float value) {
// Calculate total number of elements
@@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor(
for (int i = 0; i < dims; i++) {
n_element *= ne[i];
}
size_t size = n_element * sizeof(float);
size_t size = n_element * ggml_type_size(dtype);
// Allocate or expand cache if needed
if (cache_element < n_element) {
@@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor(
cache_element = n_element;
// Initialize cache
if (value == 0.0f) {
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
} else {
int64_t pool_ne[1] = { n_element };
size_t pool_nb[1] = { sizeof(float) };
aclTensor* acl_value = ggml_cann_create_tensor(
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
aclnn_fill_scalar(ctx, 1, acl_value);
ggml_cann_release_resources(ctx, acl_value);
}
int64_t pool_ne[1] = { n_element };
size_t pool_nb[1] = { ggml_type_size(dtype) };
aclTensor* acl_value = ggml_cann_create_tensor(
*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
pool_ne, pool_nb, 1);
aclnn_fill_scalar(ctx, value, acl_value);
ggml_cann_release_resources(ctx, acl_value);
}
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
ggml_type_size(dtype), ne, nb, dims);
}
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// build gamma, one...
// build gamma.
size_t acl_gamma_nb[GGML_MAX_DIMS];
acl_gamma_nb[0] = sizeof(float);
// gamma's type is the same with dst.
acl_gamma_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
}
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
aclTensor* acl_gamma = get_cache_acl_tensor(
ctx,
&ctx.rms_norm_one_tensor_cache.cache,
ctx.rms_norm_one_tensor_cache.size,
src->ne,
acl_gamma_nb,
dst->type,
1, // dims
1.0f // value
);
// build rstd, zero...
// build rstd.
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
// rstd will always be F32.
acl_rstd_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
}
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
aclTensor* acl_rstd = get_cache_acl_tensor(
ctx,
&ctx.rms_norm_zero_tensor_cache.cache,
ctx.rms_norm_zero_tensor_cache.size,
acl_rstd_ne,
acl_rstd_nb,
GGML_TYPE_F32,
GGML_MAX_DIMS - 1,
0.0f // value
);
@@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0]; // src
ggml_tensor* src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
switch (src0->type) {
case GGML_TYPE_F32: {
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
break;
}
case GGML_TYPE_F16: {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if(src0->type == dst->type) {
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
} else {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = dst->nb[0];
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
break;
}
case GGML_TYPE_Q8_0: {
// add 1 dim for bcast mul.
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
@@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
*dequant_ne;
int64_t scale_offset = 0;
// [3,4,5,64] -> [3,4,5,2,32]
weight_ne[0] = QK8_0;
weight_ne[1] = src0->ne[0] / QK8_0;
@@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
weight_ne[i] = src0->ne[i - 1];
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
}
// [3,4,5,64] -> [3,4,5,2,1]
scale_ne[0] = 1;
scale_ne[1] = src0->ne[0] / QK8_0;
@@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
scale_ne[i] = src0->ne[i - 1];
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
}
// [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne;
dequant_nb[0] = sizeof(float);
dequant_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
}
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float));
ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
GGML_MAX_DIMS + 1);
@@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
aclTensor* dequant_tensor = ggml_cann_create_tensor(
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
dequant_nb[0] = sizeof(float);
dequant_nb[0] = ggml_type_size(dst->type);
dequant_ne = src0->ne;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
}
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
dequant_ne, dequant_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, dequant_tensor);
ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
break;
}
default:
@@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
int64_t acl_stride[2] = {1, transpose_ne[1]};
// Reverse ne.
std::reverse(transpose_ne, transpose_ne + n_dims);
std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
acl_weight_tensor = aclCreateTensor(
transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
acl_weight_tensor =
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {
acl_weight_tensor =
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclTensor* acl_src0_f16_tensor = nullptr;
aclTensor* acl_src1_f16_tensor = nullptr;
aclTensor* acl_src2_f16_tensor = nullptr;
aclTensor* acl_dst_f16_tensor = nullptr;
// Step 1: cast the src0 (Query) to fp16 if needed
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
src2_bsnd_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}
acl_dst_f16_tensor = ggml_cann_create_tensor(
out_f16_buffer, faDataType, faElemSize,
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
);
// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
aclTensor* bcast_pse_tensor = nullptr;
@@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclTensor* acl_q_tensor = acl_src0_f16_tensor;
aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
int64_t numHeads = src0->ne[2]; // N
int64_t numKeyValueHeads = src1->ne[2];
@@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
int64_t keyAntiquantMode = 0;
int64_t valueAntiquantMode = 0;
// Step 5: launch the FusedInferAttentionScoreV2 kernel.
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
aclTensor * fa_dst_tensor = nullptr;
aclTensor * acl_dst_tensor = nullptr;
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) {
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}
fa_dst_tensor = ggml_cann_create_tensor(
out_f16_buffer, faDataType, faElemSize,
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
);
}
else {
fa_dst_tensor = ggml_cann_create_tensor(dst);
}
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
@@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
blockSize, antiquantMode, // blockSize, antiquantMode
softmaxLseFlag, // softmaxLseFlag
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
acl_dst_f16_tensor, // attentionOut
fa_dst_tensor, // attentionOut
nullptr // softmaxLse
);
// Step 6: post-processing, permute and cast to f32
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
// TODO: when dst is fp16, don't need cast
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
acl_dst_f16_tensor,
acl_dst_tensor);
if(src3 != nullptr){
ggml_cann_release_resources(ctx, bcast_pse_tensor);
if (dst->type == GGML_TYPE_F32) {
// Step 6: post-processing, permute and cast to f32
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
}
}else{
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_k_tensor_list,
acl_v_tensor_list,
fa_dst_tensor,
acl_dst_tensor,
bcast_pse_tensor);
} else {
GGML_ABORT("Function is not implemented.");
}
}
+8 -1
View File
@@ -341,11 +341,18 @@ private:
#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
// dst tensor
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
// src tensor
void * src_address[GGML_MAX_SRC];
int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
// op
ggml_op node_op;
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
+37 -11
View File
@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
if (node->src[src]) {
prop.src_address[src] = node->src[src]->data;
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
} else {
prop.src_address[src] = nullptr;
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
}
}
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
static bool ggml_graph_node_has_matching_properties(
ggml_tensor * node,
ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_VIEW) {
node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != graph_node_properties->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
if (node->src[i]) {
if (node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW) {
return false;
}
for (int d = 0; d < GGML_MAX_DIMS; d++) {
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
return false;
}
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
return false;
}
}
} else {
if (graph_node_properties->src_address[i] != nullptr) {
return false;
}
}
}
if (node->op == GGML_OP_SCALE &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}
+1
View File
@@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
// src1 must be host buffer
+1 -1
View File
@@ -68,7 +68,7 @@ struct ggml_compute_params {
#endif // __VXE2__
#endif // __s390x__ && __VEC__
#if defined(__ARM_FEATURE_SVE)
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
#include <sys/prctl.h>
#endif
+6 -1
View File
@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
#endif
static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__linux__)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
#else
// TODO: add support of SVE for non-linux systems
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
#endif
#endif
}
+209 -96
View File
@@ -29,6 +29,108 @@
#define NELEMS(x) sizeof(x) / sizeof(*x)
template<size_t(*Fn)(size_t,size_t,size_t)>
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
return Fn(a, b, c);
}
template<size_t(*Fn)(size_t,size_t)>
static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {
return Fn(a, b);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>
static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
return Fn(m, k, bl, mr, kr, sr);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
return Fn(m, k, mr, kr, sr);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
return Fn(m_idx, k, bl, mr, kr, sr);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
return Fn(m_idx, k, mr, kr, sr);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
return Fn(n, k, nr, kr, bl);
}
template<size_t(*Fn)(size_t,size_t)>
static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
return Fn(n, k);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t)>
static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {
return Fn(k, nr, kr, bl);
}
template<size_t(*Fn)(size_t)>
static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
return Fn(k);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>
static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,
void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr, bl,
static_cast<const uint8_t*>(rhs),
static_cast<const float*>(bias),
rhs_packed, extra_bytes,
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);
}
static const size_t INT4_PER_BYTE = 2;
static const size_t INT4_BITS = 4;
static const int Q4_0_ZERO_POINT = 8;
@@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
},
/* SME GEMV */
/* .kern_info = */ {
@@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
},
/* SME GEMV */
/* .kern_info = */ {
@@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset_ex = */ nullptr,
/* .get_rhs_packed_offset_ex = */ nullptr,
/* .run_kernel_ex = */ nullptr,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .packed_stride = */ NULL,
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .to_float = */ NULL,
/* .packed_stride = */ nullptr,
/* .to_float = */ nullptr,
/* .packed_size_ex = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
/* .packed_stride_ex = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
/* .pack_func_ex = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
@@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
},
/* i8mm GEMV */
/* .kern_info = */ {
@@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
},
/* i8mm GEMV */
/* .kern_info = */ {
@@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
@@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
ggml_kleidiai_kernels * kernel = nullptr;
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
@@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
break;
}
}
#endif
}
return kernel;
@@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i];
break;
}
}
#endif
return kernels;
}
+32 -44
View File
@@ -4,8 +4,6 @@
#pragma once
#include <functional>
#include <variant>
#include "ggml.h"
enum cpu_feature {
@@ -15,6 +13,7 @@ enum cpu_feature {
CPU_FEATURE_SVE = 4,
CPU_FEATURE_SME = 8
};
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
lhs = static_cast<cpu_feature>(lhs | rhs);
return lhs;
@@ -30,63 +29,52 @@ struct kernel_info {
size_t (*get_nr)(void);
size_t (*get_kr)(void);
size_t (*get_sr)(void);
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t m_idx, size_t k)>
> get_lhs_offset;
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t n_idx, size_t k)>
> get_rhs_packed_offset;
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
size_t (*get_dst_size)(size_t m, size_t n);
std::variant<
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
size_t dst_stride_col, float clamp_min, float clamp_max)>
> run_kernel;
size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);
size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);
void (*run_kernel_ex)(
size_t m, size_t n, size_t k, size_t bl,
const void* lhs_packed, const void* rhs_packed,
void* dst, size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max);
};
struct lhs_packing_info {
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
> get_packed_offset;
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
> packed_size;
std::variant<
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
size_t lhs_stride, void* lhs_packed)>,
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
void* lhs_packed)>
> pack_func;
size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
};
struct rhs_packing_info {
std::variant<
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
std::function<size_t(size_t n, size_t k)>
> packed_size;
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
std::variant<
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
> pack_func;
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
size_t kr, size_t bl, size_t num_bytes_multiplier);
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,
size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,
size_t num_bytes_multiplier);
size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);
void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);
};
struct ggml_kleidiai_kernels {
kernel_info gemm;
kernel_info gemm;
lhs_packing_info gemm_lhs_info;
kernel_info gemv;
kernel_info gemv;
lhs_packing_info gemv_lhs_info;
rhs_packing_info rhs_info;
+51 -73
View File
@@ -8,6 +8,7 @@
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
@@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
return tensor->ne[dim];
}
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
using V = std::remove_reference_t<Variant>;
return (std::is_invocable_r_v<
Ret,
std::variant_alternative_t<Is, V>,
Args...> || ...);
}
template <typename Variant, typename Ret, typename... Args>
constexpr bool variant_any_invocable_v =
variant_any_invocable_impl<Variant, Ret, Args...>(
std::make_index_sequence<
std::variant_size_v<std::remove_reference_t<Variant>>>{});
template<typename Ret, typename Variant, typename... Args>
static inline Ret variant_call(Variant && var, Args&&... args) {
static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
"No alternative in Variant is invocable with the provided arguments and return type.");
return std::visit(
[&](auto && f) -> Ret {
using F = std::decay_t<decltype(f)>;
if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
} else {
GGML_ABORT("Invalid function type in variant_call");
GGML_UNREACHABLE();
}
},
std::forward<Variant>(var)
);
}
namespace ggml::cpu::kleidiai {
static size_t round_down(size_t x, size_t y) {
@@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return false;
}
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels);
if (!kernels) {
return false;
}
bool is_gemv = op->src[1]->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
@@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t sr = kernel->get_sr();
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
k * n * sizeof(float) + n * sizeof(float);
} else {
GGML_ASSERT(false);
return false;
}
return true;
@@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
if (!kernels) {
return false;
}
const bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel);
if (!kernels->rhs_info.pack_func_ex ||
!kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
return false;
}
const int nth = params->nth;
const int ith = params->ith;
@@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t kr = (int64_t) kernel->get_kr();
const int64_t sr = (int64_t) kernel->get_sr();
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
const size_t bias_size = (size_t)n * sizeof(float);
const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
const size_t kxn_size = k * n * sizeof(float);
const size_t bias_size = n * sizeof(float);
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
GGML_ASSERT(wsize_required <= params->wsize);
@@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
// Base packed offset (aligned) and per-row stride in bytes
const size_t base_packed_off = variant_call<size_t>(
lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t next_block_off = variant_call<size_t>(
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
int64_t remaining = m_count;
@@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
void * dst_ptr = lhs_packed + dst_off;
variant_call<void>(lhs_info->pack_func,
(size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
/*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
cur += take;
remaining -= take;
@@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
reinterpret_cast<const uint16_t *>(rhs_batch_base),
rhs_stride);
variant_call<void>(kernels->rhs_info.pack_func,
/*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
/*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
}
ggml_barrier(params->threadpool);
@@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
// LHS packed base at row 0 (consistent with packing above)
const size_t lhs_packed_offset0 = variant_call<size_t>(
lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
variant_call<void>(kernel->run_kernel,
(size_t)m, (size_t)n_to_process, (size_t)k,
lhs_ptr, rhs_ptr,
dst_ptr, dst_stride, sizeof(float),
-FLT_MAX, FLT_MAX);
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
}
}
@@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
if (!kernels) {
return false;
}
bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel);
if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
return false;
}
const int ith = params->ith;
const int nth_raw = params->nth;
@@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits {
// Transform LHS
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
// Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
// Perform the operation
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
if (n_to_process > 0) {
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
}
@@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
if (!ctx.kernels) {
return false;
}
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
kernel_info * kernel = &ctx.kernels->gemm;
if (!rhs_info->to_float || !kernel->get_nr) {
return false;
}
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
@@ -480,7 +458,7 @@ public:
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
return 0;
GGML_UNUSED(data_size);
@@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_
const size_t nr = ctx.kernels->gemm.get_nr();
const size_t kr = ctx.kernels->gemm.get_kr();
return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
GGML_UNUSED(buft);
}
+11 -15
View File
@@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32(
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)x[i00];
}
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, x);
float mean = sum/ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float variance = 0;
ggml_float sum2 = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v;
sum2 += (ggml_float)(v*v);
}
#ifdef GGML_USE_ACCELERATE
mean = -mean;
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
vDSP_measqv(y, 1, &variance, ne00);
#else
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
#endif //GGML_USE_ACCELERATE
float variance = sum2/ne00;
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
@@ -8135,7 +8131,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
// V /= S
const float S_inv = 1.0f/S;
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
// dst indices
+66
View File
@@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
}
}
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
int i = 0;
ggml_float sum = 0;
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
#if defined(__AVX512F__) && defined(__AVX512DQ__)
for (; i + 15 < n; i += 16) {
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
_mm512_set1_ps(mean));
_mm512_storeu_ps(y + i, val);
sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
}
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
_mm256_set1_ps(mean));
_mm256_storeu_ps(y + i, val);
val = _mm256_mul_ps(val,val);
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
_mm256_castps256_ps128(val));
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
sum += (ggml_float)_mm_cvtss_f32(val2);
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
_mm_set1_ps(mean));
_mm_storeu_ps(y + i, val);
val = _mm_mul_ps(val, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
val = _mm_add_ps(val, _mm_movehl_ps(val, val));
val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
__m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
val = _mm_add_ps(val, tmp);
tmp = _mm_movehl_ps(tmp, val);
val = _mm_add_ss(val, tmp);
#endif // __AVX__ || __AVX2__ || __AVX512F__
sum += (ggml_float)_mm_cvtss_f32(val);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = vsubq_f32(vld1q_f32(x + i),
vdupq_n_f32(mean));
vst1q_f32(y + i, val);
val = vmulq_f32(val, val);
sum += (ggml_float)vaddvq_f32(val);
}
#elif defined(__VXE__) || defined(__VXE2__)
for (; i + 3 < n; i += 4) {
float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
vec_xst(val, 0, y + i);
val = vec_mul(val, val);
sum += (ggml_float)vec_hsum_f32x4(val);
}
#endif
for (; i < n; ++i) {
float val = x[i] - mean;
y[i] = val;
val *= val;
sum += (ggml_float)val;
}
return sum/n;
}
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
ggml_float sum = 0;
+10 -8
View File
@@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_silu_f32(const int n, float * y, const float * x);
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
@@ -143,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
@@ -159,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
@@ -654,11 +655,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
}
// leftovers
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
if (np < n) {
svbool_t pg = svwhilelt_b32(np, n);
ay1 = svld1_f32(pg, y + np);
for (int i = np; i < n; i += ggml_f32_epr) {
svbool_t pg = svwhilelt_b32(i, n);
ay1 = svld1_f32(pg, y + i);
ay1 = svmul_f32_m(pg, ay1, vx);
svst1_f32(pg, y + np, ay1);
svst1_f32(pg, y + i, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
@@ -819,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
const float v = GGML_CPU_FP16_TO_FP32(x[i]);
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
}
}
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+2
View File
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_CUDA "*.cu")
file(GLOB SRCS "template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")
+6 -1
View File
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
}
static bool fast_fp16_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
}
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
// Important: do not use this function if dst and src both point at registers.
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
if constexpr (alignment != 0) {
+3 -6
View File
@@ -793,8 +793,6 @@ void launch_fattn(
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
@@ -916,8 +914,7 @@ void launch_fattn(
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
} else {
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +943,7 @@ void launch_fattn(
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = Q->ne[2]*Q->ne[3];
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+30 -741
View File
@@ -1,756 +1,45 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
#include "fattn-wmma-f16.cuh"
// kq_stride == number of KQ rows to process per iteration
// kq_nbatch == number of K columns to load in parallel for KQ calculation
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA(cc)) {
switch (D) {
case 64:
return 128;
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return ncols == 32 ? 128 : 64;
case 128:
return ncols == 32 ? 64 : 32;
case 256:
return 32;
default:
GGML_ABORT("fatal error");
return -1;
}
}
if (fast_fp16_available(cc)) {
switch (D) {
case 64:
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
GGML_ABORT("fatal error");
return -1;
}
GGML_UNUSED(warp_size);
}
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
#ifdef RDNA
switch (D) {
case 64:
return 128;
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols == 32 ? 128 : 64;
case 128:
return ncols == 32 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return 64;
case 128:
case 256:
return 128;
default:
return -1;
}
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
return 64;
case 128:
case 256:
return 128;
default:
return -1;
}
#else
switch (D) {
case 64:
return 64;
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
return 256;
GGML_UNUSED_VARS(cc, ncols);
}
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
return 256;
GGML_UNUSED(ncols);
}
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
#ifdef RDNA
return 3;
#else
return ncols <= 16 ? 3 : 2;
#endif // RDNA
GGML_UNUSED(ncols);
}
template<int D, int ncols, bool use_logit_softcap> // D == head size
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
static __global__ void flash_attn_tile(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef GGML_USE_WMMA_FATTN
NO_DEVICE_CODE;
return;
#endif // GGML_USE_WMMA_FATTN
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = 32;
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
constexpr int cpw = ncols/nwarps; // cols per warp
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
#ifdef FAST_FP16_AVAILABLE
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
__shared__ half2 Q_tmp[ncols][D/2];
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#else
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
__shared__ float Q_tmp[ncols][D];
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#endif // FAST_FP16_AVAILABLE
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
float KQ_max[cpw];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
}
float KQ_sum[cpw] = {0.0f};
// Load Q data, convert to FP16 if fast.
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
const int j = j0 + threadIdx.y*cpw;
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
float tmp_f[cpy_ne_D] = {0.0f};
if (ic0 + j < ne01) {
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp_f[i1] *= scale;
}
#ifdef FAST_FP16_AVAILABLE
half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
}
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
#else
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
// Main loop over KV cache:
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
// Calculate KQ tile and keep track of new maximum KQ values:
float KQ_max_new[cpw];
#pragma unroll
for (int j = 0; j < cpw; ++j) {
KQ_max_new[j] = KQ_max[j];
}
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
// KQ = K @ Q matrix multiplication:
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
}
#else
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
half2 tmp_h2[cpy_ne_kqnb/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
float2 tmp_f2[cpy_ne_kqnb/2];
#pragma unroll
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
}
#endif // FAST_FP16_AVAILABLE
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
half2 K_k[kq_stride/warp_size][cpy_ne];
half2 Q_k[cpw][cpy_ne];
#else
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
float K_k[kq_stride/warp_size][cpy_ne];
float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
#pragma unroll
for (int k = 0; k < cpy_ne; ++k) {
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
}
}
}
}
if (k_KQ_0 + kq_nbatch < D) {
__syncthreads(); // Sync not needed on last iteration.
}
}
// Apply logit softcap, mask, update KQ_max:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
if (use_logit_softcap) {
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
}
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
}
}
__syncthreads();
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
#ifdef FAST_FP16_AVAILABLE
half tmp[kq_stride/warp_size][softmax_iter_j];
#else
float tmp[kq_stride/warp_size][softmax_iter_j];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
KQ_max[j0+j1] = KQ_max_new[j0+j1];
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
KQ_sum_add += val;
tmp[i0/warp_size][j1] = val;
}
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
}
}
// VKQ = V @ KQ matrix multiplication:
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
#pragma unroll
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
const int k_tile = k1 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
}
#else
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
half2 tmp_h2[cpy_ne_D/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
float2 tmp_f2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
}
#endif // FAST_FP16_AVAILABLE
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
half2 V_k[(D/2)/warp_size];
half2 KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
half tmp[softmax_iter_j];
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
&tmp, KQ[j][k0 + k1]);
#pragma unroll
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
KQ_k[j0+j1] = __half2half2(tmp[j1]);
}
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
}
}
}
#else
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
float2 V_k[(D/2)/warp_size];
float KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
&KQ_k[j0], KQ[j][k0 + k1]);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
}
}
}
#endif // FAST_FP16_AVAILABLE
__syncthreads();
}
}
// Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
KQ_max[j0] = KQ_max_new_j;
const float val = expf(sink - KQ_max[j0]);
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
if (threadIdx.x == 0) {
KQ_sum[j0] += val;
}
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
}
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
}
if (gridDim.y == 1) {
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
#pragma unroll
for (int i = 0; i < (D/2)/warp_size; ++i) {
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
}
#else
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
#pragma unroll
for (int i = 0; i < (D/2)/warp_size; ++i) {
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}
// Write back results:
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
float2 tmp[cpy_ne_D];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
}
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
}
#else
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
}
#endif // FAST_FP16_AVAILABLE
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = 32;
constexpr size_t nbytes_shared = 0;
#ifdef GGML_USE_HIP
if constexpr (D <= 128) {
if (Q->ne[1] > 32) {
constexpr int cols_per_block = 64;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
}
#endif // GGML_USE_HIP
if (Q->ne[1] > 16) {
constexpr int cols_per_block = 32;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
constexpr int cols_per_block = 16;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
}
template <bool use_logit_softcap>
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
switch (K->ne[0]) {
case 40: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
} break;
case 64: {
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
} break;
case 80: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
} break;
case 96: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
} break;
case 112: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
} break;
case 128: {
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
} break;
case 256: {
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
} break;
default: {
GGML_ABORT("Unsupported head size");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
}
}
File diff suppressed because it is too large Load Diff
+2
View File
@@ -1,3 +1,5 @@
#pragma once
#include "common.cuh"
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+41 -29
View File
@@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
#endif// FLASH_ATTN_AVAILABLE
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
@@ -206,31 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
const int cc = ggml_cuda_info().devices[device].cc;
switch (K->ne[0]) {
case 40:
case 64:
case 128:
case 256:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 80:
case 96:
case 128:
case 112:
case 256:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 576:
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
@@ -264,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
// If Turing tensor cores available, use them except for some cases with batch size 1:
if (turing_mma_available(cc)) {
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
best = BEST_FATTN_KERNEL_VEC;
return BEST_FATTN_KERNEL_VEC;
}
} else {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
if (Q->ne[1] <= 2) {
best = BEST_FATTN_KERNEL_VEC;
return BEST_FATTN_KERNEL_VEC;
}
} else {
if (Q->ne[1] == 1) {
best = BEST_FATTN_KERNEL_VEC;
return BEST_FATTN_KERNEL_VEC;
}
}
}
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
if (!gqa_opt_applies && Q->ne[1] == 1) {
return BEST_FATTN_KERNEL_VEC;
}
}
return best;
return BEST_FATTN_KERNEL_MMA_F16;
}
// Use kernels specialized for small batch sizes if possible:
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
return BEST_FATTN_KERNEL_VEC;
}
// For large batch sizes, use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc)) {
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
return BEST_FATTN_KERNEL_WMMA_F16;
}
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
if (Q->ne[1] == 1) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC;
}
}
} else {
if (Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
}
}
return BEST_FATTN_KERNEL_TILE;
}
+1 -2
View File
@@ -231,7 +231,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
info.devices[id].integrated = prop.integrated;
info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;
@@ -3867,7 +3867,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
ggml_cuda_set_device(i);
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(112, 112);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(128, 128);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(256, 256);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(40, 40);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(576, 512);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(64, 64);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(80, 80);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(96, 96);
@@ -3,8 +3,17 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
"""
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
@@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
for filename in glob("*.cu"):
os.remove(filename)
for head_size_kq in HEAD_SIZES_KQ:
head_size_v = head_size_kq if head_size_kq != 576 else 512
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
for type_k in TYPES_KV:
for type_v in TYPES_KV:
with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
@@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
for head_size_kq in HEAD_SIZES_KQ:
if head_size_kq == 40:
continue
if head_size_kq != 576 and ncols2 == 16:
continue
if head_size_kq == 576 and ncols2 != 16:
+2
View File
@@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
+2 -2
View File
@@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t
}
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
}
@@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor *
}
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
return false;
+174 -10
View File
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@@ -338,7 +357,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
char base[256];
char name[256];
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
const char * suffix = "";
if (op->src[1]->ne[0] % 4 == 0) {
suffix = "_4";
}
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +377,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
char base[256];
char name[256];
if (op->src[3]->ne[0] == 1) {
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
} else {
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s", base);
const int nsg = (ne00 + 31)/32;
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
@@ -369,7 +394,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
return res;
}
@@ -918,6 +943,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_pad");
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
base,
has_mask,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_blk");
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
base,
nqptg,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const ggml_tensor * op,
@@ -925,6 +1040,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -937,18 +1053,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
// do bounds checks for the mask?
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
"flash_attn_ext",
ggml_type_name(op->src[1]->type),
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
bc_mask,
ns10,
ns20,
nsg);
@@ -964,6 +1085,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1107,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1127,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
ns10,
ns20,
nsg, nwg);
@@ -1023,6 +1149,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -1374,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_SGD);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
+17
View File
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -134,6 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
@@ -142,6 +157,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +167,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg);
+5 -3
View File
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
@@ -776,9 +777,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};
}
case GGML_OP_GET_ROWS:
{
return op->ne[3] == 1;
}
return true;
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {
@@ -800,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
};
}
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
default:
return false;
}
+73 -7
View File
@@ -69,11 +69,20 @@
#define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
#define FC_FLASH_ATTN_EXT_PAD 100
#define FC_FLASH_ATTN_EXT_BLK 200
#define FC_FLASH_ATTN_EXT 300
#define FC_FLASH_ATTN_EXT_VEC 400
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
// kernel argument structs
//
@@ -178,6 +187,7 @@ typedef struct {
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
int64_t ne01;
int64_t ne02;
@@ -243,6 +253,35 @@ typedef struct {
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;
typedef struct {
int32_t ne01;
int32_t ne30;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_blk;
typedef struct {
int32_t ne01;
int32_t ne02;
@@ -261,6 +300,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@@ -295,6 +335,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@@ -503,6 +544,10 @@ typedef struct{
float limit;
} ggml_metal_kargs_glu;
typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -572,32 +617,45 @@ typedef struct {
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t s_off;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t ns12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t ns21;
uint64_t nb22;
int64_t ne30;
uint64_t nb31;
uint64_t nb41;
uint64_t nb42;
uint64_t ns42;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t ns52;
uint64_t nb53;
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int64_t ne00;
int32_t ne00t;
int32_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb03;
int32_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_get_rows;
typedef struct {
@@ -719,4 +777,12 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
#endif // GGML_METAL_IMPL
+439 -88
View File
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
ggml_is_contiguous(node->src[1]), node->src[1]->name);
}
if (node->src[2]) {
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
ggml_is_contiguous(node->src[2]), node->src[2]->name);
}
if (node->src[3]) {
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
ggml_is_contiguous(node->src[3]), node->src[3]->name);
}
if (node) {
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
node->name);
@@ -289,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_glu(ctx, idx);
} break;
case GGML_OP_SUM:
{
n_fuse = ggml_metal_op_sum(ctx, idx);
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
@@ -398,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argmax(ctx, idx);
} break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
case GGML_OP_OPT_STEP_SGD:
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -577,6 +601,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
@@ -827,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
ggml_metal_kargs_sum args = {
/*.np =*/ n,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -906,23 +955,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
ggml_metal_kargs_get_rows args = {
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nw0 = (args.ne00t + nth - 1)/nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
return 1;
}
@@ -1117,7 +1174,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
@@ -1172,25 +1229,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ns12 =*/ nb12/nb10,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.ns21 =*/ nb21/nb20,
/*.nb22 =*/ nb22,
/*.ne30 =*/ ne30,
/*.nb31 =*/ nb31,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.ns42 =*/ nb42/nb40,
/*.nb43 =*/ nb43,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.ns52 =*/ nb52/nb50,
/*.nb53 =*/ nb53,
/*.nb0 =*/ nb0,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1274,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
if (ne30 == 1) {
// Mamba-2
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
} else {
GGML_ASSERT(d_inner == 1);
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
}
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
return 1;
}
@@ -1273,26 +1335,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
// TODO: support
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
const int32_t nk00 = ne00;
int nth = 32; // SIMD width
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
int64_t nk0 = ne00;
if (ggml_is_quantized(op->src[0]->type)) {
nk0 = ne00/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne00/ggml_blck_size(op->type);
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk00) {
nrptg = (nth + nk00 - 1)/nk00;
nth = nk00;
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
@@ -1300,10 +1359,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
}
}
nth = std::min(nth, nk00);
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.ne00 =*/ nk00,
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
@@ -1321,12 +1381,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
return 1;
}
@@ -1520,9 +1582,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
!ggml_is_transposed(op->src[1]) &&
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1875,20 +1936,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
return (ne01 < 20) && (ne00 % 32 == 0);
}
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
}
return res;
}
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (!has_mask) {
return res;
}
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
// this optimization is not useful for the vector kernels
if (is_vec) {
return res;
}
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
return res;
}
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t nwg = 32;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
const int64_t ne01 = op->src[0]->ne[1];
const int64_t ne02 = op->src[0]->ne[2];
const int64_t ne03 = op->src[0]->ne[3];
const int64_t ne20 = op->src[2]->ne[0];
size_t res = 0;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const int64_t nwg = 32;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
}
return res;
}
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1910,8 +2058,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2068,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne12 == ne22);
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
float scale;
float max_bias;
@@ -1949,15 +2096,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne01 < 65536);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_pad = bid_dst;
bid_pad.offs += ggml_nbytes(op);
ggml_metal_buffer_id bid_blk = bid_pad;
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
ggml_metal_buffer_id bid_tmp = bid_blk;
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
bool need_sync = false;
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (has_mask) {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
ggml_metal_kargs_flash_attn_ext_blk args0 = {
/*.ne01 =*/ ne01,
/*.ne30 =*/ ne30,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}
if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
// 2*(2*ncpsg)
@@ -2007,6 +2250,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
@@ -2023,24 +2267,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
}
if (op->src[4]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2048,14 +2286,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
const int64_t nkpsg = 1*ncpsg;
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
const int nkpsg = 1*ncpsg;
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
bool need_sync = false;
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}
// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
@@ -2120,6 +2406,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
@@ -2136,25 +2423,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
}
if (op->src[4]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
const size_t smem = FATTN_SMEM(nsg);
@@ -2162,23 +2441,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
if (nwg == 1) {
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
// using 1 workgroup -> write the result directly into dst
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
} else {
// sanity checks
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
// write the results from each workgroup into a temp buffer
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
@@ -3156,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_sgd args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
+5
View File
@@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
// return true if we should use the FA vector kernel for this op
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
@@ -48,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
@@ -76,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}
+3 -3
View File
@@ -193,9 +193,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
}
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
default:
break;
File diff suppressed because it is too large Load Diff
+2
View File
@@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
+301 -137
View File
@@ -105,9 +105,12 @@ enum rpc_cmd {
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_COUNT,
};
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
@@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp {
uint8_t patch;
};
struct rpc_msg_device_count_rsp {
uint32_t device_count;
};
struct rpc_msg_get_alloc_size_req {
uint32_t device;
rpc_tensor tensor;
};
@@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req {
};
struct rpc_msg_alloc_buffer_req {
uint32_t device;
uint64_t size;
};
@@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp {
uint64_t remote_size;
};
struct rpc_msg_get_alignment_req {
uint32_t device;
};
struct rpc_msg_get_alignment_rsp {
uint64_t alignment;
};
struct rpc_msg_get_max_size_req {
uint32_t device;
};
struct rpc_msg_get_max_size_rsp {
uint64_t max_size;
};
@@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp {
uint8_t result;
};
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
@@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() {
struct ggml_backend_rpc_buffer_type_context {
std::string endpoint;
uint32_t device;
std::string name;
size_t alignment;
size_t max_size;
size_t alignment;
size_t max_size;
};
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
};
@@ -608,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
RPC_STATUS_ASSERT(status);
}
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
}
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
// check if src and dst are on the same server
ggml_backend_buffer_t src_buffer = src->buffer;
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
return false;
if (ggml_backend_buffer_is_rpc(src->buffer)) {
// check if src and dst are on the same server
ggml_backend_buffer_t src_buffer = src->buffer;
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
return false;
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_copy_tensor_req request;
request.src = serialize_tensor(src);
request.dst = serialize_tensor(dst);
rpc_msg_copy_tensor_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.result;
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
rpc_msg_copy_tensor_req request;
request.src = serialize_tensor(src);
request.dst = serialize_tensor(dst);
rpc_msg_copy_tensor_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.result;
return false;
}
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -653,7 +683,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
rpc_msg_alloc_buffer_req request = {size};
rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
rpc_msg_alloc_buffer_rsp response;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -669,9 +699,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
}
}
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
rpc_msg_get_alignment_req request = {device};
rpc_msg_get_alignment_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.alignment;
}
@@ -681,9 +712,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
return buft_ctx->alignment;
}
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
rpc_msg_get_max_size_req request = {device};
rpc_msg_get_max_size_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.max_size;
}
@@ -700,7 +732,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
auto sock = get_socket(buft_ctx->endpoint);
rpc_msg_get_alloc_size_req request;
request.device = buft_ctx->device;
request.tensor = serialize_tensor(tensor);
rpc_msg_get_alloc_size_rsp response;
@@ -754,7 +786,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
tensors.push_back(serialize_tensor(tensor));
}
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
uint32_t n_nodes = cgraph->n_nodes;
std::vector<rpc_tensor> tensors;
std::unordered_set<ggml_tensor*> visited;
@@ -762,24 +794,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
add_tensor(cgraph->nodes[i], tensors, visited);
}
// serialization format:
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
uint32_t n_tensors = tensors.size();
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
output.resize(output_size, 0);
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
uint8_t * dest = output.data();
memcpy(dest, &device, sizeof(device));
dest += sizeof(device);
memcpy(dest, &n_nodes, sizeof(n_nodes));
dest += sizeof(n_nodes);
for (uint32_t i = 0; i < n_nodes; i++) {
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
}
uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
*out_ntensors = n_tensors;
rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
dest += n_nodes * sizeof(uint64_t);
memcpy(dest, &n_tensors, sizeof(n_tensors));
dest += sizeof(n_tensors);
rpc_tensor * out_tensors = (rpc_tensor *)dest;
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
}
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input;
serialize_graph(cgraph, input);
serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
@@ -804,12 +841,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_optimize = */ NULL,
};
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
// NOTE: buffer types are allocated and never freed; this is by design
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
auto it = buft_map.find(endpoint);
auto it = buft_map.find(buft_name);
if (it != buft_map.end()) {
return it->second;
}
@@ -818,34 +856,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
return nullptr;
}
size_t alignment = get_alignment(sock);
size_t max_size = get_max_size(sock);
size_t alignment = get_alignment(sock, device);
size_t max_size = get_max_size(sock, device);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]",
/* .device = */ device,
/* .name = */ buft_name,
/* .alignment = */ alignment,
/* .max_size = */ max_size
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint),
/* .device = */ ggml_backend_reg_dev_get(reg, device),
/* .context = */ buft_ctx
};
buft_map[endpoint] = buft;
buft_map[buft_name] = buft;
return buft;
}
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]",
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .iface = */ ggml_backend_rpc_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint),
/* .device = */ ggml_backend_reg_dev_get(reg, device),
/* .context = */ ctx
};
return backend;
@@ -855,37 +896,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
rpc_msg_get_device_memory_req request;
request.device = device;
rpc_msg_get_device_memory_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
*free = response.free_mem;
*total = response.total_mem;
}
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
auto sock = get_socket(endpoint);
if (sock == nullptr) {
*free = 0;
*total = 0;
return;
}
get_device_memory(sock, free, total);
get_device_memory(sock, device, free, total);
}
// RPC server-side implementation
class rpc_server {
public:
rpc_server(ggml_backend_t backend, const char * cache_dir)
: backend(backend), cache_dir(cache_dir) {
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
: backends(std::move(backends)), cache_dir(cache_dir) {
}
~rpc_server();
void hello(rpc_msg_hello_rsp & response);
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
void get_alignment(rpc_msg_get_alignment_rsp & response);
void get_max_size(rpc_msg_get_max_size_rsp & response);
bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
bool free_buffer(const rpc_msg_free_buffer_req & request);
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -906,7 +949,7 @@ private:
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
ggml_backend_t backend;
std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
};
@@ -919,6 +962,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
}
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft;
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -935,10 +982,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
return false;
}
LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
if (tensor->buffer == nullptr) {
//No buffer allocated.
buft = ggml_backend_get_default_buffer_type(backend);
buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
} else {
buft = tensor->buffer->buft;
}
@@ -948,33 +995,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
return true;
}
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
response.remote_ptr = 0;
response.remote_size = 0;
if (buffer != nullptr) {
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
response.remote_size = buffer->size;
LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
__func__, dev_id, request.size, response.remote_ptr, response.remote_size);
buffers.insert(buffer);
} else {
LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
}
return true;
}
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
size_t alignment = ggml_backend_buft_get_alignment(buft);
LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
response.alignment = alignment;
return true;
}
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
size_t max_size = ggml_backend_buft_get_max_size(buft);
LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
response.max_size = max_size;
return true;
}
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
@@ -1332,23 +1395,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
// serialization format:
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < sizeof(uint32_t)) {
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
return false;
}
const uint8_t * src = input.data();
uint32_t device;
memcpy(&device, src, sizeof(device));
src += sizeof(device);
if (device >= backends.size()) {
return false;
}
uint32_t n_nodes;
memcpy(&n_nodes, input.data(), sizeof(n_nodes));
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
memcpy(&n_nodes, src, sizeof(n_nodes));
src += sizeof(n_nodes);
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
return false;
}
const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
const uint64_t * nodes = (const uint64_t *)src;
src += n_nodes*sizeof(uint64_t);
uint32_t n_tensors;
memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
memcpy(&n_tensors, src, sizeof(n_tensors));
src += sizeof(n_tensors);
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
return false;
}
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
const rpc_tensor * tensors = (const rpc_tensor *)src;
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
@@ -1380,7 +1453,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return false;
}
}
ggml_status status = ggml_backend_graph_compute(backend, graph);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
return true;
}
@@ -1391,9 +1464,9 @@ rpc_server::~rpc_server() {
}
}
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
rpc_server server(backend, cache_dir);
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
@@ -1425,13 +1498,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
// HELLO command is handled above
return;
}
case RPC_CMD_DEVICE_COUNT: {
if (!recv_msg(sockfd, nullptr, 0)) {
return;
}
rpc_msg_device_count_rsp response;
response.device_count = backends.size();
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
case RPC_CMD_ALLOC_BUFFER: {
rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_alloc_buffer_rsp response;
server.alloc_buffer(request, response);
if (!server.alloc_buffer(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1452,22 +1538,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
case RPC_CMD_GET_ALIGNMENT: {
if (!recv_msg(sockfd, nullptr, 0)) {
rpc_msg_get_alignment_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_get_alignment_rsp response;
server.get_alignment(response);
if (!server.get_alignment(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
case RPC_CMD_GET_MAX_SIZE: {
if (!recv_msg(sockfd, nullptr, 0)) {
rpc_msg_get_max_size_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_get_max_size_rsp response;
server.get_max_size(response);
if (!server.get_max_size(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1593,12 +1685,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
case RPC_CMD_GET_DEVICE_MEMORY: {
if (!recv_msg(sockfd, nullptr, 0)) {
rpc_msg_get_device_memory_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
auto dev_id = request.device;
if (dev_id >= backends.size()) {
return;
}
rpc_msg_get_device_memory_rsp response;
response.free_mem = free_mem;
response.total_mem = total_mem;
response.free_mem = free_mem[dev_id];
response.total_mem = total_mem[dev_id];
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1612,16 +1711,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
}
}
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
const char * cache_dir,
size_t free_mem, size_t total_mem) {
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
RPC_PROTO_PATCH_VERSION);
printf(" endpoint : %s\n", endpoint);
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
return;
}
backends.push_back(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend, n_threads);
}
}
}
std::string host;
int port;
@@ -1649,22 +1773,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
fprintf(stderr, "Failed to accept client connection\n");
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
printf("Client connection closed\n");
fflush(stdout);
}
#ifdef _WIN32
WSACleanup();
#endif
for (auto backend : backends) {
ggml_backend_free(backend);
}
}
// device interface
struct ggml_backend_rpc_device_context {
std::string endpoint;
uint32_t device;
std::string name;
std::string description;
};
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
@@ -1676,15 +1805,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
return ctx->name.c_str();
return ctx->description.c_str();
}
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
GGML_UNUSED(dev);
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
}
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1710,7 +1837,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
return ggml_backend_rpc_init(ctx->endpoint.c_str());
return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
GGML_UNUSED(params);
}
@@ -1718,7 +1845,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
GGML_UNUSED(dev);
}
@@ -1736,7 +1863,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
return buft_ctx->endpoint == dev_ctx->endpoint;
return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
}
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1759,28 +1886,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
// backend reg interface
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
return "RPC";
struct ggml_backend_rpc_reg_context {
std::string name;
std::vector<ggml_backend_dev_t> devices;
};
GGML_UNUSED(reg);
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
return ctx ? ctx->name.c_str() : "RPC";
}
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
return 0;
GGML_UNUSED(reg);
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
return ctx ? ctx->devices.size() : 0;
}
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
GGML_UNUSED(reg);
GGML_UNUSED(index);
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
if (ctx == nullptr) {
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
} else {
GGML_ASSERT(index < ctx->devices.size());
return ctx->devices[index];
}
}
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
return (void *)ggml_backend_rpc_add_device;
if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
return (void *)ggml_backend_rpc_add_server;
}
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
return (void *)ggml_backend_rpc_start_server;
@@ -1807,30 +1940,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
return &ggml_backend_rpc_reg;
}
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (dev_map.find(endpoint) != dev_map.end()) {
return dev_map[endpoint];
}
ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC[" + std::string(endpoint) + "]",
};
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_rpc_device_i,
/* .reg = */ ggml_backend_rpc_reg(),
/* .context = */ ctx,
};
dev_map[endpoint] = dev;
return dev;
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
auto sock = get_socket(endpoint);
rpc_msg_device_count_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.device_count;
}
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
/* .get_name = */ ggml_backend_rpc_reg_get_name,
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
/* .get_device = */ ggml_backend_rpc_reg_get_device,
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
static std::mutex mutex;
static uint32_t dev_id = 0;
std::lock_guard<std::mutex> lock(mutex);
if (reg_map.find(endpoint) != reg_map.end()) {
return reg_map[endpoint];
}
uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
if (dev_count == 0) {
return nullptr;
}
ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
ctx->name = "RPC[" + std::string(endpoint) + "]";
for (uint32_t ind = 0; ind < dev_count; ind++) {
std::string dev_name = "RPC" + std::to_string(dev_id);
std::string dev_desc = std::string(endpoint);
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
/* .endpoint = */ endpoint,
/* .device = */ ind,
/* .name = */ dev_name,
/* .description = */ dev_desc
};
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_rpc_device_i,
/* .reg = */ ggml_backend_rpc_reg(),
/* .context = */ dev_ctx,
};
ctx->devices.push_back(dev);
dev_id++;
}
ggml_backend_reg_t reg = new ggml_backend_reg {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_rpc_reg_interface,
/* .context = */ ctx
};
reg_map[endpoint] = reg;
return reg;
}
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
+2
View File
@@ -18,6 +18,7 @@
#include "concat.hpp"
#include "conv.hpp"
#include "convert.hpp"
#include "count-equal.hpp"
#include "cpy.hpp"
#include "dequantize.hpp"
#include "dmmv.hpp"
@@ -28,6 +29,7 @@
#include "mmvq.hpp"
#include "norm.hpp"
#include "outprod.hpp"
#include "pad.hpp"
#include "quantize.hpp"
#include "quants.hpp"
#include "rope.hpp"
-9
View File
@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);
-6
View File
@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}
static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}
+74 -15
View File
@@ -195,8 +195,10 @@ struct optimize_feature {
struct sycl_device_info {
int cc; // compute capability
// int nsm; // number of streaming multiprocessors
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
// number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool vmm; // virtual memory support
size_t total_vram;
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
@@ -416,13 +418,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
const sycl::nd_item<3>& item_ct1) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
/*
DPCT1096:98: The right-most dimension of the work-group used in the SYCL
kernel that calls this function may be less than "32". The function
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
CPU device. Modify the size of the work-group to ensure that the value
of the right-most dimension is a multiple of "32".
*/
x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
}
return x;
@@ -440,17 +435,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
return a;
}
template <int width = WARP_SIZE>
static __dpct_inline__ int warp_reduce_sum(int x) {
return sycl::reduce_over_group(
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
}
template <int width = WARP_SIZE>
static __dpct_inline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
x += dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
}
return x;
}
template <int width = WARP_SIZE>
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
a.x() += dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
width);
a.y() += dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
width);
}
return a;
}
template <int width = WARP_SIZE>
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
a = a + dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
width);
}
return a;
}
static constexpr int ggml_sycl_get_physical_warp_size() {
// todo: for old iGPU + dGPU case, need to be changed.
return WARP_SIZE;
}
template <int width = WARP_SIZE>
static __dpct_inline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
offset, width));
}
return x;
}
static __dpct_inline__ float warp_reduce_max(float x,
const sycl::nd_item<3>& item_ct1) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
/*
DPCT1096:97: The right-most dimension of the work-group used in the SYCL
kernel that calls this function may be less than "32". The function
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
CPU device. Modify the size of the work-group to ensure that the value
of the right-most dimension is a multiple of "32".
*/
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
item_ct1.get_sub_group(), x, mask));
}
@@ -558,4 +603,18 @@ struct scope_op_debug_print {
std::string_view func_suffix;
};
static __dpct_inline__ float get_alibi_slope(const float max_bias,
const uint32_t h,
const uint32_t n_head_log2,
const float m0,
const float m1) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
return dpct::pow(base, exph);
}
#endif // GGML_SYCL_COMMON_HPP
+79
View File
@@ -0,0 +1,79 @@
#include "count-equal.hpp"
#include <cstdint>
template <typename T>
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
int64_t *__restrict__ dst, const int64_t dk,
const int64_t k) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
const int64_t i1 = sycl::min(i0 + dk, k);
int nequal = 0;
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
const T xi = x[i];
const T yi = y[i];
nequal += xi == yi;
}
nequal = warp_reduce_sum(nequal);
if (item_ct1.get_local_id(2) != 0) {
return;
}
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
(int *)dst, nequal);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT( dst->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
int64_t * dst_d = (int64_t *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int id = get_current_device_id();
const int nsm = ggml_sycl_info().devices[id].nsm;
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
const int64_t dne =
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
const dpct::dim3 block_nums(
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
SYCL_COUNT_EQUAL_CHUNK_SIZE),
1, 1);
switch (src0->type) {
case GGML_TYPE_I32: {
const int *src0_d = (const int *)src0->data;
const int *src1_d = (const int *)src1->data;
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
count_equal(src0_d, src1_d, dst_d, dne, ne);
GGML_UNUSED(item_ct1);
});
} break;
default:
GGML_ASSERT(false);
break;
}
}
+9
View File
@@ -0,0 +1,9 @@
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
#define GGML_SYCL_COUNT_EQUAL_HPP
#include "common.hpp"
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif //GGML_SYCL_COUNT_EQUAL_HPP
+20
View File
@@ -277,6 +277,26 @@ namespace dpct
} // namespace detail
// COPY from DPCT head files
/// dim3 is used to store 3 component dimensions.
class dim3 {
public:
unsigned x, y, z;
constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
: x(x), y(y), z(z) {}
dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
}; // namespace dim3
inline dim3 operator*(const dim3 &a, const dim3 &b) {
return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
}
// COPY from DPCT head files
/// Pitched 2D/3D memory data.
class pitched_data
{
-78
View File
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}
template <typename T>
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
item_ct1.get_group(0) * ne00 * ne01;
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = static_cast<T>(0.0f);
}
}
template<typename T>
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
const sycl::nd_item<1> &item_ct1) {
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
});
}
template<typename T>
static void pad_sycl(const T *x, T *dst, const int ne00,
const int ne01, const int ne02, const int ne0,
const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
}
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
}
}
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
break;
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
}
}
} // namespace ggml_sycl_detail
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
});
}
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
queue_ptr stream) {
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
});
}
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
float min_val;
float max_val;
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_upscale(ctx, dst);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
-2
View File
@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+75 -51
View File
@@ -85,7 +85,10 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].nsm = prop.get_max_compute_units();
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].smpbo = prop.get_local_mem_size();
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
}
@@ -1511,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
template <ggml_sort_order order>
__dpct_inline__ static void
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
uint8_t *dpct_local) {
// bitonic sort
int col = item_ct1.get_local_id(2);
int col_index = item_ct1.get_local_id(2);
int row = item_ct1.get_group(1);
if (col >= ncols_pad) {
return;
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col >= ncols_pad) {
return;
}
}
const float * x_row = x + row * ncols;
auto dst_row = (int *)dpct_local;
// initialize indices
dst_row[col] = col;
for (int i=0;i<tasks_per_thread;i++){
int col = col_index*tasks_per_thread+i;
dst_row[col] = col;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
: x_row[dst_row[col]] <
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
: x_row[dst_row[col]] >
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
/*
DPCT1118:1: SYCL group functions and algorithms must be encountered
in converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
}
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
const sycl::nd_item<3> &item_ct1) {
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -1737,11 +1750,20 @@ static int next_power_of_2(int x) {
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, ggml_sort_order order,
queue_ptr stream) {
queue_ptr stream, int device) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const sycl::range<3> block_dims(1, 1, ncols_pad);
int nth = 1;
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
while (nth < ncols_pad && nth < max_block_size)
nth *= 2;
if (nth > max_block_size)
nth = max_block_size;
const int tasks_per_thread = ncols_pad / nth;
const sycl::range<3> block_dims(1, 1, nth);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
@@ -1754,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
@@ -1768,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
@@ -2141,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
main_stream, ctx.device);
}
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -3741,6 +3766,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SOFT_MAX:
ggml_sycl_op_soft_max(ctx, dst);
break;
case GGML_OP_SOFT_MAX_BACK:
ggml_sycl_op_soft_max_back(ctx, dst);
break;
case GGML_OP_ROPE:
ggml_sycl_rope(ctx, dst);
break;
@@ -3778,6 +3806,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
return true;
} catch (sycl::exception & e) {
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
std::exit(1);
}
@@ -4386,19 +4415,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_SOFT_MAX:
// TODO: support batching
if (op->src[0]->ne[3] != 1) {
return false;
}
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_SOFT_MAX_BACK: {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
return true;
@@ -4412,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ACC:
return true;
case GGML_OP_PAD:
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
return ggml_is_contiguous(op->src[0]);
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
+97
View File
@@ -0,0 +1,97 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//#include "common.hpp"
#include "pad.hpp"
static void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
int i0 = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
int i1 = item_ct1.get_group(1);
int i2 = item_ct1.get_group(0) % ne2;
int i3 = item_ct1.get_group(0) / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
i02 * (ne00 * ne01) + i01 * ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] = 0.0f;
}
}
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3,
const int rp3, const int ne0, const int ne1,
const int ne2, const int ne3,
dpct::queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
ne2, ne3);
});
}
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_sycl(src0_d, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}
+24
View File
@@ -0,0 +1,24 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_PAD_HPP
#define GGML_SYCL_PAD_HPP
#include "common.hpp"
#define SYCL_PAD_BLOCK_SIZE 256
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_PAD_HPP
+328 -163
View File
@@ -1,37 +1,94 @@
#include "softmax.hpp"
#include <cstdint>
#include <utility>
#include <cmath>
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = item_ct1.get_local_id(2);
const int rowx = item_ct1.get_group(2);
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
template <typename T> static __dpct_inline__ float t2f32(T val) {
return (float) val;
}
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
template <> float __dpct_inline__ t2f32<sycl::half>(sycl::half val) {
return sycl::vec<sycl::half, 1>(val)
.convert<float, sycl::rounding_mode::automatic>()[0];
}
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
struct soft_max_params {
int64_t nheads;
uint32_t n_head_log2;
int64_t ncols;
int64_t nrows_x;
int64_t nrows_y;
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
int64_t nb11;
int64_t nb12;
int64_t nb13;
int64_t ne12;
int64_t ne13;
float scale;
float max_bias;
float m0;
float m1;
};
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static void soft_max_f32(const float * x,
const T * mask,
const float * sinks,
float * dst,
const soft_max_params p,
uint8_t * dpct_local) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int block_size = block_size_template == 0
? item_ct1.get_local_range(2)
: block_size_template;
const int nthreads = block_size;
const int nwarps = nthreads / WARP_SIZE;
size_t nreduce = nwarps / WARP_SIZE;
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = rowx/nrows_y; // head index
const int tid = item_ct1.get_local_id(2);
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
const int64_t i03 = item_ct1.get_group(0);
const int64_t i02 = item_ct1.get_group(1);
const int64_t i01 = item_ct1.get_group(2);
slope = sycl::pow(base, float(exp));
}
//TODO: noncontigous inputs/outputs
const int rowx = item_ct1.get_group(2) +
item_ct1.get_group(1) * item_ct1.get_group_range(2) +
item_ct1.get_group(0) * item_ct1.get_group_range(2) *
item_ct1.get_group_range(1);
float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
float max_val = -INFINITY;
const int64_t i11 = i01;
const int64_t i12 = i02 % p.ne12;
const int64_t i13 = i03 % p.ne13;
x += int64_t(rowx)*ncols;
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
dst += int64_t(rowx)*ncols;
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
float * buf_iw = (float *) dpct_local;
// shared memory buffer to cache values between iterations:
float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst;
float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
@@ -39,42 +96,35 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
break;
}
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
vals[col] = val;
max_val = sycl::max(max_val, val);
max_val = sycl::max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val, item_ct1);
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
for (size_t i = 1; i < nreduce; i += 1) {
buf[lane_id + i * WARP_SIZE] = -INFINITY;
}
buf_iw[lane_id] = -INFINITY;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
item_ct1.barrier();
if (lane_id == 0) {
buf[warp_id] = max_val;
buf_iw[warp_id] = max_val;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
max_val = buf[lane_id];
for (size_t i = 1; i < nreduce; i += 1) {
max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
}
max_val = warp_reduce_max(max_val, item_ct1);
}
item_ct1.barrier();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
float tmp = 0.f;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
if (ncols_template == 0 && col >= ncols) {
break;
}
@@ -82,32 +132,33 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
tmp += val;
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp, item_ct1);
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
item_ct1.barrier(sycl::access::fence_space::local_space);
item_ct1.barrier();
if (warp_id == 0) {
buf[lane_id] = 0.f;
buf_iw[lane_id] = 0.0f;
for (size_t i = 1; i < nreduce; i += 1) {
buf[lane_id + i * WARP_SIZE] = 0.f;
buf_iw[lane_id + i * WARP_SIZE] = 0.f;
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
item_ct1.barrier();
if (lane_id == 0) {
buf[warp_id] = tmp;
buf_iw[warp_id] = tmp;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
item_ct1.barrier();
tmp = buf[lane_id];
tmp = buf_iw[lane_id];
for (size_t i = 1; i < nreduce; i += 1) {
tmp += buf[lane_id + i * WARP_SIZE];
tmp += buf_iw[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
tmp = warp_reduce_sum(tmp);
}
const float inv_sum = 1.f / tmp;
if (sinks) {
tmp += sycl::native::exp(sinks[i02] - max_val);
}
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -117,145 +168,259 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
return;
}
const int idst = rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
dst[col] = vals[col] * inv_sum;
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
const int ncols, const float scale) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int tid = item_ct1.get_local_id(2);
const int rowx = item_ct1.get_group(2);
grad += int64_t(rowx)*ncols;
dstf += int64_t(rowx)*ncols;
dst += int64_t(rowx)*ncols;
float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
for (int col = tid; col < ncols; col += WARP_SIZE) {
dgf_dot += dstf[col]*grad[col];
}
dgf_dot = warp_reduce_sum(dgf_dot);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
}
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
const size_t n_local_scratch, queue_ptr stream) {
template <int... Ns, typename T>
static void launch_soft_max_kernels(const float * x,
const T * mask,
const float * sinks,
float * dst,
const soft_max_params & p,
dpct::queue_ptr stream,
dpct::dim3 block_dims,
dpct::dim3 block_nums,
size_t nbytes_shared)
{
auto launch_kernel = [=](auto I) -> bool {
constexpr int ncols = decltype(I)::value;
constexpr int block = (ncols > 1024 ? 1024 : ncols);
if (p.ncols == ncols) {
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(nbytes_shared), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
WARP_SIZE)]] {
soft_max_f32<true, ncols, block>(
x, mask, sinks, dst, p,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
GGML_UNUSED(item_ct1);
});
});
return true;
}
return false;
};
// unary fold over launch_kernel
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
return;
}
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(nbytes_shared), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0,
m1, n_head_log2, item_ct1,
get_pointer(local_buf_acc));
});
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<true, 0, 0>(
x, mask, sinks, dst, p,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
GGML_UNUSED(item_ct1);
});
});
}
template<typename T>
static void soft_max_f32_sycl(const float * x, const T * mask,
float * dst, const int ncols_x, const int nrows_x,
const int nrows_y, const float scale, const float max_bias,
queue_ptr stream, int device) {
template <typename T>
static void soft_max_f32_sycl(const float *x, const T *mask,
const float *sinks, float *dst,
const soft_max_params &params,
dpct::queue_ptr stream, int device) {
int nth = WARP_SIZE;
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
const int64_t ncols_x = params.ncols;
while (nth < ncols_x && nth < max_block_size) nth *= 2;
if (nth>max_block_size) nth = max_block_size;
const sycl::range<3> block_dims(1, 1, nth);
const sycl::range<3> block_nums(1, 1, nrows_x);
const size_t n_val_tmp = nth / WARP_SIZE;
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
const dpct::dim3 block_dims(nth, 1, 1);
const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03);
const size_t nbytes_shared =
(GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float);
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const int id = get_current_device_id();
const size_t smpbo = ggml_sycl_info().devices[id].smpbo;
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
if (n_local_scratch*sizeof(float) < local_mem_size) {
if (ncols_x > max_block_size) {
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
return;
}
switch (ncols_x) {
case 32:
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 64:
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 128:
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 256:
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 512:
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 1024:
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 2048:
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
case 4096:
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
default:
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, n_local_scratch, stream);
break;
}
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
x, mask, sinks, dst, params, stream, block_dims, block_nums,
nbytes_shared);
} else {
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
max_bias, m0, m1, n_head_log2, block_nums,
block_dims, WARP_SIZE, stream);
const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(nbytes_shared_low), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
soft_max_f32<false, 0, 0>(
x, mask, sinks, dst, params,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
GGML_UNUSED(item_ct1);
});
});
}
}
static void soft_max_back_f32_sycl(const float * grad,
const float * dstf,
float * dst,
const int ncols,
const int nrows,
const float scale,
dpct::queue_ptr stream) {
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
const dpct::dim3 block_nums(nrows, 1, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
soft_max_back_f32(grad, dstf, dst, ncols, scale);
GGML_UNUSED(item_ct1);
});
}
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
// src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows_x = ggml_nrows(dst->src[0]);
const int64_t nrows_y = dst->src[0]->ne[1];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
float scale = 1.0f;
const int64_t ne00 = src0->ne[0];
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
ggml_sycl_set_device(ctx.device);
dpct::queue_ptr main_stream = ctx.stream();
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
main_stream, ctx.device);
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
soft_max_params params = {};
params.nheads = src0->ne[2];
params.n_head_log2 = n_head_log2;
params.ncols = ne00;
params.nrows_x = nrows_x;
params.nrows_y = nrows_y;
params.ne00 = src0->ne[0];
params.ne01 = src0->ne[1];
params.ne02 = src0->ne[2];
params.ne03 = src0->ne[3];
params.nb11 = nb11;
params.nb12 = nb12;
params.nb13 = nb13;
params.ne12 = ne12;
params.ne13 = ne13;
params.scale = scale;
params.max_bias = max_bias;
params.m0 = m0;
params.m1 = m1;
if (use_f16) {
soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d,
(const float *)src2_d, dst_d, params, stream,
ctx.device);
} else {
/* mask unavailable */
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d,
dst_d, params, stream, ctx.device);
}
}
void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0]; // grad
const ggml_tensor * src1 = dst->src[1]; // forward pass output
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
GGML_ASSERT(max_bias == 0.0f);
soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
}
+4
View File
@@ -15,6 +15,10 @@
#include "common.hpp"
#define SYCL_SOFT_MAX_BLOCK_SIZE 1024
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_SOFTMAX_HPP
+28 -17
View File
@@ -1,5 +1,6 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED)
@@ -54,25 +55,25 @@ if (Vulkan_FOUND)
# Test all shader extensions
test_shader_extension_support(
"GL_KHR_cooperative_matrix"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp"
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_NV_cooperative_matrix2"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp"
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_integer_dot_product"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_bfloat16"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp"
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
)
@@ -160,7 +161,6 @@ if (Vulkan_FOUND)
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
@@ -176,24 +176,35 @@ if (Vulkan_FOUND)
add_custom_command(
OUTPUT ${_ggml_vk_header}
${_ggml_vk_source}
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
--input-dir ${_ggml_vk_input_dir}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
--target-cpp ${_ggml_vk_source}
--no-clean
DEPENDS ${_ggml_vk_shader_files}
${_ggml_vk_shaders_gen_sources}
DEPENDS ${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders"
COMMENT "Generate vulkan shaders header"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})
foreach (file_full ${_ggml_vk_shader_files})
get_filename_component(file ${file_full} NAME)
set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp")
add_custom_command(
OUTPUT ${_ggml_vk_target_cpp}
DEPFILE ${_ggml_vk_target_cpp}.d
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
--source ${file_full}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
--target-cpp ${_ggml_vk_target_cpp}
DEPENDS ${file_full}
${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders for ${file}"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp})
endforeach()
else()
message(WARNING "Vulkan not found")
+95 -103
View File
@@ -393,6 +393,7 @@ struct vk_device_struct {
vk::PhysicalDeviceProperties properties;
std::string name;
uint64_t max_memory_allocation_size;
uint64_t max_buffer_size;
uint64_t suballocation_block_size;
bool fp16;
bool bf16;
@@ -1563,6 +1564,12 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_backend_vk_free(ggml_backend_t backend);
static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
return range;
}
// Wait for ctx->fence to be signaled.
static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
// Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
@@ -2012,8 +2019,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
if (size > device->max_memory_allocation_size) {
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
if (size > device->max_buffer_size) {
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
}
vk_buffer buf = std::make_shared<vk_buffer_struct>();
@@ -2159,8 +2166,8 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
buf.reset();
}
static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
return { buf, 0, VK_WHOLE_SIZE };
static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
}
static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
@@ -2614,8 +2621,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
};
@@ -3855,17 +3860,27 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
} else if (maintenance4_support) {
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
} else {
device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
}
const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");
if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
} else if (maintenance4_support) {
device->max_buffer_size = props4.maxBufferSize;
} else {
device->max_buffer_size = device->max_memory_allocation_size;
}
const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
} else {
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
device->suballocation_block_size = 1024*1024*1024;
@@ -6150,9 +6165,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
(split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) {
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -6227,7 +6242,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
} else if (qx_needs_dequant) {
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
@@ -6239,7 +6254,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6250,7 +6265,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6272,14 +6287,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
}
// No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
@@ -6446,8 +6458,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
}
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -6512,7 +6524,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
}
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
@@ -6521,7 +6533,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6532,7 +6544,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6931,8 +6943,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -6999,7 +7011,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
} else if (qx_needs_dequant) {
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
@@ -7012,7 +7024,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -7145,8 +7157,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -7212,7 +7224,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
}
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
@@ -7221,7 +7233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -7457,8 +7469,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
aligned = false;
}
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
@@ -7498,7 +7508,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) {
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
GGML_ABORT("Requested preallocation size is too large");
}
if (ctx->prealloc_size_split_k < split_k_size) {
@@ -7620,12 +7630,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
},
// We only use split_k when group query attention is enabled, which means
// there's no more than one tile of rows (i.e. workgroups_x would have been
@@ -7637,21 +7647,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
ctx->prealloc_split_k_need_sync = true;
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
},
pc, { workgroups_x, workgroups_y, workgroups_z });
}
@@ -8360,18 +8370,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
uint64_t d_sz = ggml_type_size(dst->type) * ned;
vk_buffer d_D = dst_buf_ctx->dev_buffer;
// Workaround for tiny tensor inputs on ROPE
if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
y_sz = VK_WHOLE_SIZE;
}
GGML_ASSERT(d_D != nullptr);
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
if(!src0_uma) {
@@ -8396,26 +8396,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
x_sz = VK_WHOLE_SIZE;
}
if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
y_sz = VK_WHOLE_SIZE;
}
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
z_sz = VK_WHOLE_SIZE;
}
if (d_buf_offset + d_sz >= d_D->size) {
d_sz = VK_WHOLE_SIZE;
}
}
std::array<uint32_t, 3> elements;
// Single call if dimension 2 is contiguous
@@ -8606,19 +8586,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
}
if (!op_supports_incontiguous) {
if (x_sz != VK_WHOLE_SIZE) {
x_sz *= ne02 * ne03;
uint64_t x_sz, y_sz, z_sz, d_sz;
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
}
if (use_src1 && y_sz != VK_WHOLE_SIZE) {
y_sz *= ne12 * ne13;
if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
}
if (use_src2 && z_sz != VK_WHOLE_SIZE) {
z_sz *= ne22 * ne23;
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
}
if (d_sz != VK_WHOLE_SIZE) {
d_sz *= ned2 * ned3;
if (d_buf_offset + d_sz >= d_D->size) {
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
}
} else {
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
}
if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
@@ -8628,7 +8620,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
vk_subbuffer{ d_D, d_buf_offset, d_sz },
vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
}, pc, elements);
} else if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
@@ -8821,18 +8813,18 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
static_assert(MAX_PARAMETER_COUNT == 12);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
ggml_vk_subbuffer(ctx, buf[0], offset[0]),
ggml_vk_subbuffer(ctx, buf[1], offset[1]),
ggml_vk_subbuffer(ctx, buf[2], offset[2]),
ggml_vk_subbuffer(ctx, buf[3], offset[3]),
ggml_vk_subbuffer(ctx, buf[4], offset[4]),
ggml_vk_subbuffer(ctx, buf[5], offset[5]),
ggml_vk_subbuffer(ctx, buf[6], offset[6]),
ggml_vk_subbuffer(ctx, buf[7], offset[7]),
ggml_vk_subbuffer(ctx, buf[8], offset[8]),
ggml_vk_subbuffer(ctx, buf[9], offset[9]),
ggml_vk_subbuffer(ctx, buf[10], offset[10]),
ggml_vk_subbuffer(ctx, buf[11], offset[11]),
}, pc, elements);
}
@@ -10006,7 +9998,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ggml_vk_ctx_begin(ctx->device, subctx);
for (size_t i = 0; i < num_it; i++) {
ggml_vk_matmul(
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
@@ -10317,7 +10309,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
//
// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
// ggml_vk_ctx_begin(ctx->device, subctx);
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
// ggml_vk_ctx_end(subctx);
//
// auto begin = std::chrono::high_resolution_clock::now();
+2 -2
View File
@@ -1,7 +1,7 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+2 -2
View File
@@ -6,8 +6,8 @@
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#include "types.comp"
#include "generic_binary_head.comp"
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
@@ -2,7 +2,7 @@
#extension GL_EXT_control_flow_attributes : require
#include "types.comp"
#include "types.glsl"
layout (push_constant) uniform parameter
{
@@ -1,7 +1,7 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;

Some files were not shown because too many files have changed in this diff Show More